# STABLE DIFFUSION IMPLEMENTATION USING CELEBA DATASET 

### SETUP

In [None]:
import torch
import functools
from tqdm import tqdm, trange
import torch.multiprocessing
from tqdm import tqdm
import torch.nn as nn
import torch.nn.functional as F
torch.multiprocessing.set_sharing_strategy('file_system')
from torch.utils.data import DataLoader, TensorDataset
from torchvision.datasets import CelebA
from torchvision.transforms import ToTensor, CenterCrop, Resize, Compose, Normalize
import math
from torch.optim import Adam
from torch.optim.lr_scheduler import MultiplicativeLR, LambdaLR
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from StableDiff_UNet_model import UNet_SD, load_pipe_into_our_UNet

### TRANSFORMATIONS SETUP AND DATASET LOADING

In [None]:
# Preprocessing pipeline: resize the shorter side to 32 px, center-crop to
# 32x32, convert to a [0, 1] float tensor, then normalize each channel with
# the ImageNet mean/std (the sampling code undoes this with `denormalize`).
tfm = Compose([
    Resize(32),
    CenterCrop(32),
    ToTensor(),
    Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
# CelebA with its binary attribute labels as the target, stored under the
# "ttoisd" root directory (downloaded on first use).
dataset_rsz = CelebA("ttoisd", target_type=["attr"], transform=tfm, download=True)

### DATA PREPARATION FOR TRAINING/EVALUATION

In [None]:
# Materialize the whole (already transformed) dataset in memory so training
# can iterate over a TensorDataset instead of re-running the transforms on
# every epoch.
dataloader = DataLoader(dataset_rsz, batch_size=64, num_workers=8, shuffle=False)
x_col = []
y_col = []
for xs, ys in tqdm(dataloader):  # iterate through the DataLoader, collecting batches
    x_col.append(xs)
    y_col.append(ys)
x_col = torch.concat(x_col, dim=0)
y_col = torch.concat(y_col, dim=0)  # one 0/1 attribute row per image
print(x_col.shape)
print(y_col.shape)

# Turn each binary attribute row into a sequence of attribute *indices* (the
# conditioning tokens fed to `cond_embed`), right-padded with `nantoken`.
# Token 40 is the embedding's padding_idx, so padded positions embed to zero.
nantoken = 40
active = y_col == 1
counts = active.sum(dim=1)
maxlen = max(1, int(counts.max()))  # at least one slot so no sequence is empty
# A stable ascending sort on "is inactive" (0 = active) moves the active
# attribute indices to the front of each row, keeping them in index order.
order = torch.sort((~active).int(), dim=1, stable=True).indices[:, :maxlen]
is_real = torch.arange(maxlen)[None, :] < counts[:, None]
yseq_data = torch.where(is_real, order, torch.full_like(order, nantoken))

saved_dataset = TensorDataset(x_col, yseq_data)

### CALCULATION: MARGINAL STANDARD DEVIATION & DIFFUSION COEFFICIENT

In [None]:
device = 'cuda'

def marginal_prob_std(t, sigma):
    """Standard deviation of the perturbation kernel at time `t`.

    For the SDE dx = sigma**t dw (whose diffusion coefficient is given by
    `diffusion_coeff` below) this is sqrt((sigma**(2t) - 1) / (2 ln sigma)).

    Args:
        t: diffusion time(s), a Python float or a tensor.
        sigma: the SDE's `sigma` (> 1).

    Returns:
        Tensor on `device` with the same shape as `t`.
    """
    # as_tensor avoids the extra copy and the copy-construct UserWarning that
    # torch.tensor() emits when `t` is already a tensor.
    t = torch.as_tensor(t, device=device)
    return torch.sqrt((sigma ** (2 * t) - 1.) / 2. / math.log(sigma))

def diffusion_coeff(t, sigma):
    """Diffusion coefficient g(t) = sigma**t, returned as a tensor on `device`."""
    return torch.as_tensor(sigma ** t, device=device)

sigma = 25.0  # @param {'type':'number'}-> sigma is a parameter that can be modified or adjusted.
marginal_prob_std_fn = functools.partial(marginal_prob_std, sigma=sigma)
diffusion_coeff_fn = functools.partial(diffusion_coeff, sigma=sigma)

### TRAINING LOSS FUNCTION

In [None]:
def loss_fn_cond(model, x, y, marginal_prob_std, eps=1e-5):
    """Conditional denoising score-matching loss.

    Draws one time t ~ U(eps, 1) per sample and perturbs the sample with
    Gaussian noise of std marginal_prob_std(t). It then penalizes the model's
    score against the kernel's true score (-noise / std). Scaling the error by
    std (the usual std**2 weighting) gives the residual score * std + noise.

    Args:
        model: time-dependent score model, called as
            model(x_t, t, cond=y, output_dict=False).
        x: mini-batch of training data; dim 0 is the batch.
        y: conditioning forwarded to `model` as `cond`.
        marginal_prob_std: function t -> std of the perturbation kernel.
        eps: smallest sampled time; keeps std away from 0 for stability.

    Returns:
        Scalar tensor: squared residual summed over dims 1-3, averaged over
        the batch.
    """
    n = x.shape[0]
    t = eps + (1. - eps) * torch.rand(n, device=x.device)
    noise = torch.randn_like(x)
    # Per-sample std, broadcast over the (C, H, W) dims.
    std = marginal_prob_std(t)[:, None, None, None]
    x_t = x + noise * std
    score = model(x_t, t, cond=y, output_dict=False)
    residual = score * std + noise
    per_sample = residual.pow(2).sum(dim=(1, 2, 3))
    return per_sample.mean()

### SAMPLING USING EULER-MARUYAMA METHOD

In [None]:
def Euler_Maruyama_sampler(score_model,
                           marginal_prob_std,
                           diffusion_coeff,
                           batch_size=64,
                           x_shape=(1, 28, 28),
                           num_steps=500,
                           device='cuda',
                           eps=1e-3,
                           y=None):
    """Generate samples by Euler-Maruyama integration of the reverse-time SDE.

    Starts from Gaussian noise scaled by marginal_prob_std(1) and walks time
    from 1 down to `eps` on a uniform grid. Each step adds the drift
    g(t)**2 * score * dt and fresh noise g(t) * sqrt(dt) * N(0, I).

    Args:
        score_model: callable (x, t, cond=y, output_dict=False) -> score.
        marginal_prob_std: t -> std of the perturbation kernel.
        diffusion_coeff: t -> diffusion coefficient g(t).
        batch_size: number of samples to draw.
        x_shape: shape of one sample, (channels, height, width).
        num_steps: number of time-grid points; must be >= 2.
        device: device to sample on.
        eps: final (smallest) time.
        y: conditioning forwarded to `score_model` as `cond`.

    Returns:
        The noise-free mean of the last step, shape (batch_size, *x_shape).

    Raises:
        ValueError: if num_steps < 2 (the step size would be undefined).
    """
    if num_steps < 2:
        raise ValueError(f"num_steps must be >= 2, got {num_steps}")
    t = torch.ones(batch_size, device=device)
    init_x = torch.randn(batch_size, *x_shape, device=device) \
             * marginal_prob_std(t)[:, None, None, None]
    time_steps = torch.linspace(1., eps, num_steps, device=device)
    step_size = time_steps[0] - time_steps[1]
    x = init_x
    with torch.no_grad():
        for time_step in tqdm(time_steps):
            batch_time_step = torch.ones(batch_size, device=device) * time_step
            g = diffusion_coeff(batch_time_step)
            mean_x = x + (g ** 2)[:, None, None, None] * score_model(x, batch_time_step, cond=y, output_dict=False) * step_size
            x = mean_x + torch.sqrt(step_size) * g[:, None, None, None] * torch.randn_like(x)
    # Do not include any noise in the last sampling step: return the mean.
    return mean_x

### TRAINING A SCORE-BASED GENERATIVE MODEL  

In [None]:
def train_score_model(score_model, cond_embed, dataset, lr, n_epochs, batch_size, ckpt_name,
                      marginal_prob_std_fn=marginal_prob_std_fn,
                      lr_scheduler_fn=lambda epoch: max(0.2, 0.98 ** epoch),
                      device="cuda",
                      callback=None): # resume=False,
    """Jointly train a conditional score model and its condition embedding.

    Each epoch runs over a shuffled DataLoader of (x, y) pairs and embeds y
    with `cond_embed`. It minimizes `loss_fn_cond` with Adam over both
    modules' parameters and steps the LambdaLR schedule. It then checkpoints
    both state dicts to fort2/ckpt_{ckpt_name}.pth and
    fort2/ckpt_{ckpt_name}_cond_embed.pth.

    Args:
        score_model: model optimized through `loss_fn_cond`.
        cond_embed: module mapping label tokens y to the conditioning `cond`.
        dataset: dataset yielding (x, y) pairs.
        lr: base learning rate for Adam.
        n_epochs: number of passes over `dataset`.
        batch_size: mini-batch size.
        ckpt_name: tag used in the checkpoint file names.
        marginal_prob_std_fn: perturbation-kernel std forwarded to the loss
            (the default is bound to the module-level function at definition
            time).
        lr_scheduler_fn: epoch -> LR multiplier for LambdaLR; the default
            decays by 0.98 per epoch, floored at 0.2.
        device: device the batches are moved to.
        callback: optional callable(score_model, epoch, ckpt_name), run after
            each epoch's checkpoint with `score_model` in eval mode.
    """
    import os

    # Checkpoints are written at the end of every epoch; create the directory
    # up front so a missing folder doesn't abort the run after an epoch of work.
    os.makedirs("fort2", exist_ok=True)
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    # One optimizer updates the score model and the condition embedding together.
    optimizer = Adam([*score_model.parameters(), *cond_embed.parameters()], lr=lr)
    scheduler = LambdaLR(optimizer, lr_lambda=lr_scheduler_fn)
    tqdm_epoch = trange(n_epochs)
    for epoch in tqdm_epoch:  # loop over epochs with a progress bar
        score_model.train()  # the callback below leaves the model in eval mode
        avg_loss = 0.
        num_items = 0
        batch_tqdm = tqdm(data_loader)
        for x, y in batch_tqdm:
            x = x.to(device)
            y_emb = cond_embed(y.to(device))  # embed the conditioning tokens
            loss = loss_fn_cond(score_model, x, y_emb, marginal_prob_std_fn)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Sample-weighted running mean (the last batch may be smaller).
            avg_loss += loss.item() * x.shape[0]
            num_items += x.shape[0]
            batch_tqdm.set_description("Epoch %d, loss %.4f" % (epoch, avg_loss / num_items))
        scheduler.step()
        # LR after stepping the scheduler, i.e. the one the next epoch uses.
        lr_current = scheduler.get_last_lr()[0]
        print('{} Average Loss: {:5f} lr {:.1e}'.format(epoch, avg_loss / num_items, lr_current))
        # Print the averaged training loss so far.
        tqdm_epoch.set_description('Average Loss: {:5f}'.format(avg_loss / num_items))
        # Update the checkpoint after each epoch of training.
        torch.save(score_model.state_dict(), f'fort2/ckpt_{ckpt_name}.pth')
        torch.save(cond_embed.state_dict(),
                   f'fort2/ckpt_{ckpt_name}_cond_embed.pth')
        if callback is not None:
            score_model.eval()
            callback(score_model, epoch, ckpt_name)

### GENERATING AND SAVING SAMPLES

In [None]:
def save_sample_callback(score_model, epocs, ckpt_name):
    """Draw a 64-image grid of conditional samples and save it under fort2/.

    Conditions on the first 64 label sequences of the global `yseq_data`,
    embedded with the global `cond_embed`. It runs the Euler-Maruyama sampler
    for 250 steps and undoes the input normalization. It then saves the grid
    to fort2/samples_{ckpt_name}_{epocs}.png and displays it.

    Args:
        score_model: score model to sample from.
        epocs: epoch number (or tag) used in the output file name.
        ckpt_name: checkpoint tag used in the output file name.
    """
    sample_batch_size = 64
    num_steps = 250
    y_samp = yseq_data[:sample_batch_size, :]
    # Sampling only: don't record an autograd graph for the embedding lookup.
    with torch.no_grad():
        y_emb = cond_embed(y_samp.to(device))
    samples = Euler_Maruyama_sampler(score_model,
                                     marginal_prob_std_fn,
                                     diffusion_coeff_fn,
                                     sample_batch_size,
                                     x_shape=(3, 32, 32),
                                     num_steps=num_steps,
                                     device=device,
                                     y=y_emb, )
    # Inverse of Normalize(mean, std): (x - (-mean/std)) / (1/std) = x*std + mean.
    denormalize = Normalize([-0.485/0.229, -0.456/0.224, -0.406/0.225],
                            [1/0.229, 1/0.224, 1/0.225])
    samples = denormalize(samples).clamp(0.0, 1.0)
    sample_grid = make_grid(samples, nrow=int(math.sqrt(sample_batch_size)))

    fig = plt.figure(figsize=(8, 8))
    plt.axis('off')
    plt.imshow(sample_grid.permute(1, 2, 0).cpu(), vmin=0., vmax=1.)
    plt.tight_layout()
    plt.savefig(f"fort2/samples_{ckpt_name}_{epocs}.png")
    plt.show()
    # This runs once per epoch; close the figure so figures don't pile up.
    plt.close(fig)

### SETTING UP A STABLE DIFFUSION U-NET MODEL

In [None]:
# UNet_SD score network for 3-channel images; moved to `device` because every
# input it sees (smoke test, training batches, sampler) lives on that device.
unet_face = UNet_SD(in_channels=3,
                    base_channels=128,
                    time_emb_dim=256,
                    context_dim=256,
                    multipliers=(1, 1, 2),
                    attn_levels=(1, 2, ),
                    nResAttn_block=1,
                    ).to(device)
# 40 attribute tokens plus padding token 40, whose embedding is a zero vector
# that receives no gradient (nn.Embedding padding_idx semantics).
cond_embed = nn.Embedding(40 + 1, 256, padding_idx=40).to(device)

### MODEL TRAINING AND SAVING

In [None]:
import os

os.makedirs("fort2", exist_ok=True)  # checkpoints and sample grids go here
torch.save(unet_face.state_dict(), "fort2/SD_unet_face.pt",)
# Smoke test: one forward pass on a random 64x64 input with a random 20-token
# context to check the network wires up; no gradients needed.
with torch.no_grad():
    unet_face(torch.randn(1, 3, 64, 64).cuda(), time_steps=torch.rand(1).cuda(),
              cond=torch.randn(1, 20, 256).cuda(),
              output_dict=False)

train_score_model(unet_face, cond_embed, saved_dataset,
                  lr=1.5e-4, n_epochs=100, batch_size=256,
                  ckpt_name="unet_SD_face", device=device,
                  callback=save_sample_callback)

# Final samples get their own file instead of overwriting the epoch-0 grid
# (samples_unet_SD_face_0.png) that the training callback already wrote.
save_sample_callback(unet_face, "final", "unet_SD_face")
torch.save(cond_embed.state_dict(), 'fort2/ckpt_unet_SD_face_cond_embed.pth')

###  DEMONSTRATION OF HOW TO MANIPULATE IMAGE ATTRIBUTES USING A TRAINED MODEL AND GENERATE NEW IMAGES WITH MODIFIED ATTRIBUTES

In [None]:
import torch
import functools
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
import math
from tqdm import tqdm, trange
import torch.multiprocessing
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from torchvision.datasets import CelebA
from torchvision.transforms import ToTensor, CenterCrop, Resize, Compose, Normalize

In [None]:
# Functions to calculate marginal probability and diffusion coefficient
def marginal_prob_std(t, sigma):
    """Standard deviation of the perturbation kernel at time `t`.

    For the SDE dx = sigma**t dw this is sqrt((sigma**(2t) - 1) / (2 ln sigma)).
    Reads the module-level `device` at call time.

    Args:
        t: diffusion time(s), a Python float or a tensor.
        sigma: the SDE's `sigma` (> 1).

    Returns:
        Tensor on `device` with the same shape as `t`.
    """
    # as_tensor avoids the extra copy and the copy-construct UserWarning that
    # torch.tensor() emits when `t` is already a tensor.
    t = torch.as_tensor(t, device=device)
    return torch.sqrt((sigma ** (2 * t) - 1.) / 2. / math.log(sigma))

def diffusion_coeff(t, sigma):
    """Diffusion coefficient g(t) = sigma**t, returned as a tensor on `device`."""
    return torch.as_tensor(sigma ** t, device=device)

sigma = 25.0
marginal_prob_std_fn = functools.partial(marginal_prob_std, sigma=sigma)
diffusion_coeff_fn = functools.partial(diffusion_coeff, sigma=sigma)

In [None]:
# Euler-Maruyama sampler for generating samples
def Euler_Maruyama_sampler(score_model, marginal_prob_std, diffusion_coeff, batch_size=64, x_shape=(3, 32, 32), num_steps=500, device='cuda', eps=1e-3, y=None):
    """Sample by Euler-Maruyama integration of the reverse-time SDE (t: 1 -> eps).

    Each step moves the state by g(t)**2 * score * dt and then adds
    g(t) * sqrt(dt) Gaussian noise. The returned tensor is the final step's
    drift-only (noise-free) update, shaped (batch_size, *x_shape).
    """
    def expand(v):
        # Broadcast a per-sample (batch_size,) tensor over the image dims.
        return v[:, None, None, None]

    ones = torch.ones(batch_size, device=device)
    x = torch.randn(batch_size, *x_shape, device=device) * expand(marginal_prob_std(ones))
    schedule = torch.linspace(1., eps, num_steps, device=device)
    dt = schedule[0] - schedule[1]
    with torch.no_grad():
        for t_now in tqdm(schedule):
            t_batch = torch.ones(batch_size, device=device) * t_now
            g = diffusion_coeff(t_batch)
            score = score_model(x, t_batch, cond=y, output_dict=False)
            mean_x = x + expand(g ** 2) * score * dt
            x = mean_x + torch.sqrt(dt) * expand(g) * torch.randn_like(x)
    return mean_x

In [None]:
# Load the trained model and embedding
# UNet_SD is not imported in this section's import cell; bring it in here so
# the section also runs in a fresh kernel.
from StableDiff_UNet_model import UNet_SD

device = 'cuda'
unet_face = UNet_SD(in_channels=3, base_channels=128, time_emb_dim=256, context_dim=256, multipliers=(1, 1, 2), attn_levels=(1, 2), nResAttn_block=1).to(device)
cond_embed = nn.Embedding(40 + 1, 256, padding_idx=40).to(device)

# map_location loads the tensors straight onto `device`, whichever device
# they were saved from.
unet_face.load_state_dict(torch.load('fort2/ckpt_unet_SD_face.pth', map_location=device))
cond_embed.load_state_dict(torch.load('fort2/ckpt_unet_SD_face_cond_embed.pth', map_location=device))
# Inference only from here on; disables train-mode behavior (e.g. dropout,
# if UNet_SD uses any).
unet_face.eval()
cond_embed.eval()

In [None]:
# Define transformations and load CelebA dataset
# Same preprocessing as in training: 32x32 center crop, then per-channel
# normalization with the ImageNet statistics.
tfm = Compose([
    Resize(32),
    CenterCrop(32),
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
dataset = CelebA("ttoisd", target_type=["attr"], transform=tfm, download=True)
# One image per batch, in dataset order, so the first batch is image 0.
dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)

In [None]:
# Get a single image and its attributes
# Take the first (image, attributes) batch and move both tensors onto `device`.
single_image, single_attr = (tensor.to(device) for tensor in next(iter(dataloader)))

In [None]:
# Define specific labels for the single image (modify as desired)
# Start from the original image's own attribute vector and switch on the
# attributes to add (0-based CelebA attribute indices).
desired_labels = torch.clone(single_attr)
desired_labels[0][15] = 1  # 15 = Eyeglasses
# Other edits to try (0-based indices):
# desired_labels[0][31] = 0  # 31 = Smiling -> remove the smile
# desired_labels[0][8] = 1   # 8 = Black_Hair
# desired_labels[0][35] = 1  # 35 = Wearing_Hat
# desired_labels[0][39] = 1  # 39 = Young

In [None]:
# Embed the labels.
# `cond_embed` consumes attribute-index tokens padded with 40 (its
# padding_idx). That is the layout of `yseq_data` in the training section,
# not a raw 0/1 attribute vector, so convert each row of `desired_labels`
# into the indices of its active attributes first.
pad_token = 40
active = (desired_labels == 1).cpu()
seq_len = max(1, int(active.sum(dim=1).max()))  # >= 1 so the context is never empty
label_tokens = torch.full((active.size(0), seq_len), pad_token, dtype=torch.long)
for row, mask in enumerate(active):
    idx = torch.nonzero(mask).flatten()
    label_tokens[row, :idx.numel()] = idx
y_emb = cond_embed(label_tokens.to(device))

In [None]:
# Function to generate and visualize sample images
def generate_and_visualize_samples(model, y_emb, original_image, device='cuda', num_samples=1, num_steps=250):
    """Sample from `model` conditioned on `y_emb` and plot it beside `original_image`.

    Both the first generated sample and the first original image are mapped
    back from the normalized space to [0, 1] before being shown side by side.

    Args:
        model: score model handed to `Euler_Maruyama_sampler`.
        y_emb: embedded conditioning, passed to the sampler as `y`.
        original_image: normalized image batch; element 0 is shown.
        device: device to sample on.
        num_samples: number of samples to draw (only the first is shown).
        num_steps: number of sampler steps.
    """
    samples = Euler_Maruyama_sampler(model, marginal_prob_std_fn, diffusion_coeff_fn, num_samples,
                                     x_shape=(3, 32, 32), num_steps=num_steps, device=device, y=y_emb)

    # Inverse of the input Normalize transform.
    denormalize = Normalize([-0.485/0.229, -0.456/0.224, -0.406/0.225], [1/0.229, 1/0.224, 1/0.225])
    panels = (
        ('Original Image', denormalize(original_image).clamp(0.0, 1.0)[0]),
        ('Generated Image', denormalize(samples).clamp(0.0, 1.0)[0]),
    )

    plt.figure(figsize=(8, 8))
    for position, (title, image) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.title(title)
        plt.imshow(image.permute(1, 2, 0).cpu(), vmin=0., vmax=1.)
        plt.axis('off')
    plt.tight_layout()
    plt.show()

# Generate and visualize samples with the desired labels
generate_and_visualize_samples(unet_face, y_emb, single_image)


In [None]:
'''CelebA attributes, numbered 1-40. The 0-based index used in the code is the number minus 1 (e.g. 32. Smiling -> index 31, 16. Eyeglasses -> index 15):
1. 5_o_Clock_Shadow, 2. Arched_Eyebrows, 3. Attractive, 4. Bags_Under_Eyes, 5. Bald, 6. Bangs, 7. Big_Lips, 8. Big_Nose, 9. Black_Hair, 10. Blond_Hair, 11. Blurry, 12. Brown_Hair,
13. Bushy_Eyebrows, 14. Chubby, 15. Double_Chin, 16. Eyeglasses, 17. Goatee, 18. Gray_Hair, 19. Heavy_Makeup, 20. High_Cheekbones, 21. Male, 22. Mouth_Slightly_Open,
23. Mustache, 24. Narrow_Eyes, 25. No_Beard, 26. Oval_Face, 27. Pale_Skin, 28. Pointy_Nose, 29. Receding_Hairline, 30. Rosy_Cheeks, 31. Sideburns, 32. Smiling,
33. Straight_Hair, 34. Wavy_Hair, 35. Wearing_Earrings, 36. Wearing_Hat, 37. Wearing_Lipstick, 38. Wearing_Necklace, 39. Wearing_Necktie, 40. Young

In [None]:
'''import torch
import functools
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
import math
from tqdm import tqdm, trange
import torch.multiprocessing
from tqdm import tqdm
import torch.nn as nn
import torch.nn.functional as F
from StableDiff_UNet_model import UNet_SD, load_pipe_into_our_UNet
from torch.multiprocessing import set_sharing_strategy
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from torchvision.datasets import CelebA
from torchvision.transforms import ToTensor, CenterCrop, Resize, Compose, Normalize

def marginal_prob_std(t, sigma):
    """Standard deviation of the perturbation kernel at time `t`.

    For the SDE dx = sigma**t dw this is sqrt((sigma**(2t) - 1) / (2 ln sigma)).
    Reads the module-level `device` at call time.

    Args:
        t: diffusion time(s), a Python float or a tensor.
        sigma: the SDE's `sigma` (> 1).

    Returns:
        Tensor on `device` with the same shape as `t`.
    """
    # as_tensor avoids the extra copy and the copy-construct UserWarning that
    # torch.tensor() emits when `t` is already a tensor.
    t = torch.as_tensor(t, device=device)
    return torch.sqrt((sigma ** (2 * t) - 1.) / 2. / math.log(sigma))
def diffusion_coeff(t, sigma):
    """Diffusion coefficient g(t) = sigma**t, returned as a tensor on `device`."""
    return torch.as_tensor(sigma ** t, device=device)


sigma = 25.0  # @param {'type':'number'}-> sigma is a parameter that can be modified or adjusted.
marginal_prob_std_fn = functools.partial(marginal_prob_std, sigma=sigma)
diffusion_coeff_fn = functools.partial(diffusion_coeff, sigma=sigma)
    
def Euler_Maruyama_sampler(score_model,
                           marginal_prob_std,
                           diffusion_coeff,
                           batch_size=64,
                           x_shape=(1, 28, 28),
                           num_steps=500,
                           device='cuda',
                           eps=1e-3,
                           y=None):
    """Reverse-time Euler-Maruyama sampler for the score-based SDE.

    Time runs over a uniform grid of `num_steps` points from 1 down to `eps`.
    At each grid point the state moves by g(t)**2 * score * dt plus
    g(t) * sqrt(dt) Gaussian noise. The value returned is the last step's
    update before that noise is added, shaped (batch_size, *x_shape).
    """
    per_sample = (slice(None), None, None, None)  # same as [:, None, None, None]
    noise0 = torch.randn(batch_size, *x_shape, device=device)
    x = noise0 * marginal_prob_std(torch.ones(batch_size, device=device))[per_sample]
    grid = torch.linspace(1., eps, num_steps, device=device)
    dt = grid[0] - grid[1]
    sqrt_dt = torch.sqrt(dt)
    with torch.no_grad():
        for t_k in tqdm(grid):
            t_vec = t_k * torch.ones(batch_size, device=device)
            g = diffusion_coeff(t_vec)[per_sample]
            drift = (g ** 2) * score_model(x, t_vec, cond=y, output_dict=False) * dt
            mean_x = x + drift
            x = mean_x + sqrt_dt * g * torch.randn_like(x)
    return mean_x
    
# Load the trained model and embedding
device = 'cuda'
unet_face = UNet_SD(in_channels=3,
                    base_channels=128,
                    time_emb_dim=256,
                    context_dim=256,
                    multipliers=(1, 1, 2),
                    attn_levels=(1, 2, ),
                    nResAttn_block=1).to(device)

cond_embed = nn.Embedding(40 + 1, 256, padding_idx=40).to(device)

# map_location loads the tensors straight onto `device`, whichever device
# they were saved from.
unet_face.load_state_dict(torch.load('fort2/ckpt_unet_SD_face.pth', map_location=device))
cond_embed.load_state_dict(torch.load('fort2/ckpt_unet_SD_face_cond_embed.pth', map_location=device))
# Inference only from here on; disables train-mode behavior (e.g. dropout,
# if UNet_SD uses any).
unet_face.eval()
cond_embed.eval()

# Define specific labels (replace with your desired attributes).
# Indices are 0-based positions in CelebA's 40-attribute list:
# 0 = 5_o_Clock_Shadow, 1 = Arched_Eyebrows, ..., 15 = Eyeglasses, ...,
# 31 = Smiling, ..., 38 = Wearing_Necktie, 39 = Young.
desired_labels = torch.zeros(64, 40, dtype=torch.int)  # Batch of 64 samples
desired_labels[:, 31] = 1  # Set 'Smiling' attribute to True
desired_labels[:, 15] = 1  # Set 'Eyeglasses' attribute to True

# The conditioning format is a sequence of attribute-index tokens padded with
# 40 (the padding_idx of `cond_embed`), as set up for `yseq_data` in training,
# not the raw 0/1 vector. Convert each row to its active-attribute indices.
pad_token = 40
active = desired_labels == 1
seq_len = max(1, int(active.sum(dim=1).max()))  # >= 1 so the context is never empty
label_tokens = torch.full((desired_labels.size(0), seq_len), pad_token, dtype=torch.long)
for row, mask in enumerate(active):
    idx = torch.nonzero(mask).flatten()
    label_tokens[row, :idx.numel()] = idx

# Embed the labels
y_emb = cond_embed(label_tokens.to(device))

# Function to generate and visualize sample images
def generate_and_visualize_samples(model, y_emb, device='cuda', num_samples=64, num_steps=250):
    """Sample `num_samples` images conditioned on `y_emb` and display them as a grid.

    Args:
        model: score model handed to `Euler_Maruyama_sampler`.
        y_emb: embedded conditioning, passed to the sampler as `y`.
        device: device to sample on.
        num_samples: number of images; the grid has int(sqrt(num_samples)) per row.
        num_steps: number of sampler steps.
    """
    raw = Euler_Maruyama_sampler(model,
                                 marginal_prob_std_fn,
                                 diffusion_coeff_fn,
                                 num_samples,
                                 x_shape=(3, 32, 32),
                                 num_steps=num_steps,
                                 device=device,
                                 y=y_emb)

    # Invert Normalize(mean, std): (x - (-mean/std)) / (1/std) == x*std + mean.
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)
    denormalize = Normalize([-m / s for m, s in zip(mean, std)], [1 / s for s in std])
    images = denormalize(raw).clamp(0.0, 1.0)
    per_row = int(math.sqrt(num_samples))
    grid = make_grid(images, nrow=per_row)

    plt.figure(figsize=(8, 8))
    plt.axis('off')
    plt.imshow(grid.permute(1, 2, 0).cpu(), vmin=0., vmax=1.)
    plt.tight_layout()
    plt.show()


# Generate and visualize samples with the desired labels
generate_and_visualize_samples(unet_face, y_emb)
