In [82]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Data

# Objectives

In [70]:
class Loss_Simple(nn.Module):
  """Plain mean-squared error between an adapted prediction and its target.

  forward(y_ada, y_pretrain) -> scalar mean of (y_ada - y_pretrain) ** 2 over all elements.
  """

  def __init__(self):
    super().__init__()

  def forward(self, y_ada, y_pretrain):
    residual = y_ada - y_pretrain
    return residual.pow(2).mean()

In [71]:
class Loss_PR(nn.Module):
  """Mean-squared discrepancy between source-model and adapted-model predictions.

  forward(y_pred_sou, y_pred_ada) -> scalar mean of (y_pred_sou - y_pred_ada) ** 2.
  """

  def __init__(self):
    super().__init__()

  def forward(self, y_pred_sou, y_pred_ada):
    gap = y_pred_sou - y_pred_ada
    return gap.pow(2).mean()

In [72]:
class PairwiseSimilarityLoss(nn.Module):
    """KL divergence between within-batch pairwise-similarity distributions.

    Both inputs are passed through ``D`` and flattened per sample. Inside each batch,
    every sample is compared with every *other* sample by cosine similarity, and a
    row-wise softmax over those off-diagonal similarities yields one probability
    distribution per sample. The loss is ``KL(p_sou || p_ada)`` summed over rows and
    divided by the batch size (the ``'batchmean'`` convention of ``F.kl_div``).

    Args:
        D: callable applied to ``z_ada`` and ``z_sou`` (the original notes describe it
            as computing denoised latent codes). Must return a tensor whose first
            dimension is the batch.

    Returns:
        A scalar tensor; 0 when the batch has fewer than two samples (no pairs exist).
    """

    def __init__(self, D):
        super(PairwiseSimilarityLoss, self).__init__()
        self.D = D
        self.sim = F.cosine_similarity
        self.sfm = F.softmax

    def _log_probs(self, feats):
        """Row-wise log-softmax over off-diagonal cosine similarities.

        Returns ``(log_p, mask)``: ``mask`` is the boolean diagonal, and the diagonal of
        ``log_p`` is set to 0 so it can never contribute ``-inf``/NaN downstream.
        """
        flat = feats.flatten(start_dim=1)
        # (B, 1, F) vs (1, B, F) -> (B, B); similarity must be taken over the feature axis,
        # not the default dim=1 (which would reduce over the broadcast batch axis).
        sim = self.sim(flat.unsqueeze(1), flat.unsqueeze(0), dim=-1)
        mask = torch.eye(sim.size(0), dtype=torch.bool, device=sim.device)
        log_p = F.log_softmax(sim.masked_fill(mask, float('-inf')), dim=1)
        return log_p.masked_fill(mask, 0.0), mask

    def forward(self, z_ada, z_sou):
        denoised_ada = self.D(z_ada)
        denoised_sou = self.D(z_sou)

        batch_size = denoised_ada.size(0)
        if batch_size < 2:
            # Every row would be all -inf before the softmax -> NaN; there is nothing to compare.
            return denoised_ada.new_zeros(())

        log_p_ada, mask = self._log_probs(denoised_ada)
        log_p_sou, _ = self._log_probs(denoised_sou)

        # KL(p_sou || p_ada) restricted to off-diagonal entries. Written out by hand because
        # F.kl_div would evaluate 0 * (-inf) = NaN on the masked diagonal.
        p_sou = log_p_sou.exp().masked_fill(mask, 0.0)
        kl = (p_sou * (log_p_sou - log_p_ada)).sum()
        return kl / batch_size

In [73]:
class HaarWaveletTransform(nn.Module):
    """Single-level 2D Haar transform returning the summed high-frequency sub-bands.

    Every non-overlapping 2x2 window (stride 2) is filtered with the LH, HL and HH Haar
    kernels and the three responses are added, i.e. ``forward(x) == LH + HL + HH``.
    Each input channel is filtered independently (depthwise convolution), so
    multi-channel images such as RGB are supported.

    Accepted input shapes: ``(H, W)``, ``(C, H, W)`` or ``(B, C, H, W)``.
    Output keeps the channel/batch layout with spatial dims halved (floor); a 2D input
    gains a leading channel dimension of size 1, i.e. returns ``(1, H//2, W//2)``.

    Raises:
        ValueError: if ``x`` is not 2D, 3D or 4D.
    """

    def __init__(self):
        super().__init__()
        # Kernels are stored as valid conv2d weights (out=1, in=1, 2, 2); forward()
        # repeats them once per input channel. All are scaled by 1/2.
        # LH: left column minus right column.
        self.register_buffer('lh_filter', torch.tensor([[[[1.0, -1.0], [1.0, -1.0]]]]) / 2.0)
        # HL: top row minus bottom row.
        self.register_buffer('hl_filter', torch.tensor([[[[1.0, 1.0], [-1.0, -1.0]]]]) / 2.0)
        # HH: diagonal difference.
        self.register_buffer('hh_filter', torch.tensor([[[[-1.0, 1.0], [1.0, -1.0]]]]) / 2.0)

    def _band(self, x, kernel):
        """Depthwise stride-2 convolution of ``x`` (B, C, H, W) with a (1, 1, 2, 2) kernel."""
        channels = x.size(1)
        # Match the input's dtype/device so the transform works even if this module was
        # never moved with .to(); one kernel copy per channel for groups=channels.
        weight = kernel.to(dtype=x.dtype, device=x.device).repeat(channels, 1, 1, 1)
        return F.conv2d(x, weight, stride=2, groups=channels)

    def forward(self, x):
        if x.ndim == 2:
            x = x.unsqueeze(0)  # add a channel dimension -> (1, H, W)
        if x.ndim == 3:
            unbatched = True
            x = x.unsqueeze(0)  # add a batch dimension -> (1, C, H, W)
        elif x.ndim == 4:
            unbatched = False
        else:
            raise ValueError(f"expected a 2D, 3D or 4D tensor, got shape {tuple(x.shape)}")

        hf = (self._band(x, self.lh_filter)
              + self._band(x, self.hl_filter)
              + self._band(x, self.hh_filter))
        return hf.squeeze(0) if unbatched else hf

# Example usage:
# Assume `image` is a PyTorch tensor of shape (1, height, width) representing a grayscale image
# The image tensor should have a batch dimension as well, so the full shape would be (batch_size, 1, height, width)
#haar_wavelet_transform = HaarWaveletTransform()

# Apply the transform to the image (forward returns a single tensor: LH + HL + HH)
#hf = haar_wavelet_transform(image.unsqueeze(0))

In [74]:
class Loss_HF(nn.Module):
  """KL divergence between pairwise-similarity distributions of high-frequency features.

  Both inputs are mapped through ``D`` and then ``HaarWaveletTransform`` (LH + HL + HH
  sub-bands) and flattened per sample. Inside each batch, every sample is compared
  with every *other* sample by cosine similarity; a row-wise softmax over those
  off-diagonal similarities gives one distribution per sample. The loss is
  ``KL(p_sou || p_ada)`` summed over rows and divided by the batch size.

  Args:
    D: callable applied to ``z_ada`` / ``z_sou``; its output must be a tensor accepted
      by ``HaarWaveletTransform`` with the batch on the first dimension.

  Returns:
    A scalar tensor; 0 when the batch has fewer than two samples (no pairs exist).
  """

  def __init__(self, D):
    super(Loss_HF, self).__init__()
    self.D = D  # forward() reads self.D, so it must be stored here
    self.hwt = HaarWaveletTransform()
    self.sim = F.cosine_similarity
    self.sfm = F.softmax

  def _log_probs(self, feats):
    """Row-wise log-softmax over off-diagonal cosine similarities.

    Returns ``(log_p, mask)``: ``mask`` is the boolean diagonal, and the diagonal of
    ``log_p`` is set to 0 so it can never contribute ``-inf``/NaN downstream.
    """
    flat = feats.flatten(start_dim=1)
    # (B, 1, F) vs (1, B, F) -> (B, B); reduce over the feature axis, not the default dim=1.
    sim = self.sim(flat.unsqueeze(1), flat.unsqueeze(0), dim=-1)
    mask = torch.eye(sim.size(0), dtype=torch.bool, device=sim.device)
    log_p = F.log_softmax(sim.masked_fill(mask, float('-inf')), dim=1)
    return log_p.masked_fill(mask, 0.0), mask

  def forward(self, z_ada, z_sou):
    # Extract the high-frequency (fine-grained) details of the mapped inputs.
    hf_ada = self.hwt(self.D(z_ada))
    hf_sou = self.hwt(self.D(z_sou))

    batch_size = hf_ada.size(0)
    if batch_size < 2:
      # Every row would be all -inf before the softmax -> NaN; there is nothing to compare.
      return hf_ada.new_zeros(())

    log_p_ada, mask = self._log_probs(hf_ada)
    log_p_sou, _ = self._log_probs(hf_sou)

    # KL(p_sou || p_ada) over off-diagonal entries only; F.kl_div would hit 0 * (-inf) = NaN
    # on the masked diagonal.
    p_sou = log_p_sou.exp().masked_fill(mask, 0.0)
    kl = (p_sou * (log_p_sou - log_p_ada)).sum()
    return kl / batch_size

In [78]:
class Loss_HFMSE(nn.Module):
  """Mean-squared error between the Haar high-frequency components of D(z_ada) and x_init.

  Args:
    D: callable mapping ``z_ada`` into the same space (and shape) as ``x_init``.
  """

  def __init__(self, D):
    super().__init__()
    self.hwt = HaarWaveletTransform()
    self.D = D

  def forward(self, z_ada, x_init):
    hf_pred = self.hwt(self.D(z_ada))
    hf_ref = self.hwt(x_init)
    return (hf_pred - hf_ref).pow(2).mean()

In [76]:
class DomainLoss(nn.Module):
  """Weighted sum of the domain-adaptation loss terms.

  total = L_simple(z_ada, z)
        + l1 * L_pr(z_pr_sou, z_pr_ada)
        + l2 * L_img(z_ada, z_pr_ada)
        + l3 * L_hf(z_ada, z_pr_ada)
        + l4 * L_hfmse(z_ada, x_init)

  Args:
    D: callable forwarded to the pairwise-similarity and high-frequency terms.
    l1, l2, l3, l4: term weights; defaults are the previously hard-coded values.
  """

  def __init__(self, D, l1=1, l2=2.5e+2, l3=2.5e+2, l4=0.6):
    super(DomainLoss, self).__init__()
    # Loss components
    self.l_simp = Loss_Simple()
    self.l_pr = Loss_PR()
    self.l_img = PairwiseSimilarityLoss(D)
    self.l_hf = Loss_HF(D)
    self.l_hfmse = Loss_HFMSE(D)

    # Loss weights
    self.l1 = l1
    self.l2 = l2
    self.l3 = l3
    self.l4 = l4

  def forward(self, z_ada, z, z_pr_sou, z_pr_ada, x_init):
    # Every term is added with its own weight (previously the HF term was *multiplied*
    # by the HF-MSE term and l4 was never used).
    return (
        self.l_simp(z_ada, z)
        + self.l1 * self.l_pr(z_pr_sou, z_pr_ada)
        + self.l2 * self.l_img(z_ada, z_pr_ada)
        + self.l3 * self.l_hf(z_ada, z_pr_ada)
        + self.l4 * self.l_hfmse(z_ada, x_init)
    )

# Model

### Model Imports

In [37]:
!pip install --upgrade diffusers[torch]



### Component Instantiation

In [38]:
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, UNet2DConditionModel, LMSDiscreteScheduler

In [39]:
# Hugging Face Hub id of the Stable Diffusion v1.5 weights (VAE and UNet subfolders).
# The original "runwayml/stable-diffusion-v1-5" repository was removed from the Hub;
# this is the maintained mirror of the same checkpoint.
model_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"

In [40]:
# CLIP ViT-L/14 tokenizer and text encoder used to build the prompt conditioning.
Tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
Encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
# The text encoder is never optimized (only UNetTrained's parameters go to the optimizer),
# so freeze it to avoid building autograd state for its weights. Trailing ';' hides the repr.
Encoder.requires_grad_(False);

tokenizer_config.json:   0%|          | 0.00/905 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/961k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/525k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/389 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.22M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/4.52k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.71G [00:00<?, ?B/s]

In [41]:
VAE = AutoencoderKL.from_pretrained(model_id, subfolder="vae")
# Frozen: only UNetTrained is optimized. Gradients still flow *through* the decoder to its
# inputs, which the decoder-based loss terms rely on. Trailing ';' hides the module repr.
VAE.requires_grad_(False);

vae/config.json:   0%|          | 0.00/547 [00:00<?, ?B/s]

diffusion_pytorch_model.safetensors:   0%|          | 0.00/335M [00:00<?, ?B/s]

In [42]:
# Two independent copies of the pretrained UNet: UNetLocked is the fixed reference,
# UNetTrained is the one being adapted.
UNetLocked = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
UNetTrained = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
# Enforce the "locked" contract: no .grad buffers can ever accumulate on the reference
# model, even if it is accidentally used outside torch.no_grad(). Trailing ';' hides the repr.
UNetLocked.requires_grad_(False);

In [43]:
# Training needs the DDPM forward process: add_noise() must return
# sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps, the input distribution the UNet's
# noise-prediction objective is defined on. LMSDiscreteScheduler.add_noise instead returns
# x0 + sigma_t * eps (a sampling-time parameterization), which would feed the UNet
# unscaled inputs during training.
from diffusers import DDPMScheduler

Scheduler = DDPMScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule='scaled_linear',
    num_train_timesteps=1000,
    clip_sample=False,  # latents are not bounded to [-1, 1]; only affects step()
)

In [85]:
def decode_latents(latents):
    """Decode with the VAE and return the image tensor.

    AutoencoderKL.decode returns a DecoderOutput wrapper; the loss terms need the
    tensor in its ``.sample`` field.
    NOTE(review): latents are passed through unscaled; if callers feed latents scaled by
    VAE.config.scaling_factor, divide by it here first — confirm against the training loop.
    """
    return VAE.decode(latents).sample


Loss = DomainLoss(decode_latents)

In [89]:
# Mini-batch size expected from the training dataloader.
batch_size = 4
# Epoch budget via integer division: 1250 // (2 * 4) == 156.
num_epochs = 1250 // (2 * batch_size)

In [90]:
# Adam over the adapted UNet's parameters only; UNetLocked is never handed to the optimizer.
Optimizer = optim.Adam(params=UNetTrained.parameters(), lr=4e-6)

In [91]:
# Linearly anneal the optimizer's learning rate toward swa_lr over `num_epochs`
# calls to lr_scheduler.step().
lr_scheduler = torch.optim.swa_utils.SWALR(
    Optimizer,
    swa_lr=1.5e-6,
    anneal_epochs=num_epochs,
    anneal_strategy="linear",
)

# Training

In [32]:
from accelerate import Accelerator
from tqdm.auto import tqdm
import os

## Training Utils

In [33]:
# from dataclasses import dataclass

#
# class TrainingConfig:
#     image_size = 128  # the generated image resolution
#     train_batch_size = 16
#     eval_batch_size = 16  # how many images to sample during evaluation
#     num_epochs = 50
#     gradient_accumulation_steps = 1
#     learning_rate = 1e-4
#     lr_warmup_steps = 500
#     save_image_epochs = 10
#     save_model_epochs = 30
#     mixed_precision = 'fp16'  # `no` for float32, `fp16` for automatic mixed precision
#     output_dir = 'ddpm-butterflies-128'  # the model name locally and on the HF Hub

#     push_to_hub = False  # whether to upload the saved model to the HF Hub
#     hub_private_repo = False
#     overwrite_output_dir = True  # overwrite the old model when re-running the notebook
#     seed = 0

# config = TrainingConfig()

## Training Loop

### Boilerplate for later

In [80]:
# def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler):
#     # Initialize accelerator and tensorboard logging
#     accelerator = Accelerator(
#         mixed_precision=config.mixed_precision,
#         gradient_accumulation_steps=config.gradient_accumulation_steps,
#         log_with="tensorboard",
#         logging_dir=os.path.join(config.output_dir, "logs")
#     )
#     if accelerator.is_main_process:
#         accelerator.init_trackers("train_example")

#     # Prepare everything
#     # There is no specific order to remember, you just need to unpack the
#     # objects in the same order you gave them to the prepare method.
#     model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
#         model, optimizer, train_dataloader, lr_scheduler
#     )

#     global_step = 0

#     # Now you train the model
#     for epoch in range(config.num_epochs):
#         progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
#         progress_bar.set_description(f"Epoch {epoch}")

#         for step, batch in enumerate(train_dataloader):
#             images = batch['images']
#             x_pr = torch.randn(images.shape).to(images.device)

#             #create c_tar using clip
#             labels_tr = batch['labels_tr']
#             tokens_tr = Tokenizer(labels_tr, padding=True, return_tensors="pt")
#             c_tar = Encoder(**tokens_tr).last_hidden_state

#             #create c_sou using clip
#             labels_so = batch['labels_so']
#             tokens_so = Tokenizer(labels_so, padding=True, return_tensors="pt")
#             c_sou = Encoder(**tokens_so).last_hidden_state

#             # Sample noise to add to the images
#             z = VAE.encode(images)
#             z_pr = VAE.encode(x_pr)
#             noise = torch.randn(z.shape).to(z.device)
#             bs = z.shape[0]

#             # Sample a random timestep for each image
#             timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bs,), device=z.device).long()

#             # Add noise to the clean images according to the noise magnitude at each timestep
#             # (this is the forward diffusion process)
#             z_t = noise_scheduler.add_noise(z, noise, timesteps)

#             # Get the random noise z_pr_t
#             z_pr_t = noise_scheduler.add_noise(z_pr, noise, timesteps)

#             with torch.no_grad():
#                 # Predict the noise residual
#                 z_pr_sou = UNetLocked(z_pr_t, timesteps, c_sou)

#             with accelerator.accumulate(model):
#                 # Predict the noise residual
#                 z_ada = model(z_t, timesteps, c_tar)["sample"]
#                 z_pr_ada = model(z_pr_t, timesteps, c_sou)["sample"]
#                 loss = Loss(z_ada, z, z_pr_sou, z_pr_ada, images)
#                 accelerator.backward(loss)

#                 accelerator.clip_grad_norm_(model.parameters(), 1.0)
#                 optimizer.step()
#                 lr_scheduler.step()
#                 optimizer.zero_grad()

#             progress_bar.update(1)
#             logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
#             progress_bar.set_postfix(**logs)
#             accelerator.log(logs, step=global_step)
#             global_step += 1

### Basic Training Loop

In [None]:
# Now you train the model.
# NOTE(review): `train_dataloader` is not defined anywhere in this notebook yet; the "Data"
# section must build it, yielding dicts with 'images' (pixel tensors — presumably scaled
# to [-1, 1] as the SD VAE expects; confirm), 'labels_tr' and 'labels_so' (prompt strings).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# nn.Module.to() updates parameters in place, so the already-built Optimizer stays valid.
for _module in (Encoder, VAE, UNetLocked, UNetTrained, Loss):
    _module.to(device)
UNetTrained.train()


def encode_prompts(labels):
    """Return CLIP last_hidden_state embeddings for a list of prompts, without autograd.

    Pads/truncates to the tokenizer's model_max_length, giving fixed-length conditioning
    and preventing over-long prompts from overflowing the text encoder's positions.
    """
    tokens = Tokenizer(
        labels,
        padding="max_length",
        max_length=Tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    ).to(device)
    with torch.no_grad():
        return Encoder(**tokens).last_hidden_state


def encode_latents(pixels):
    """VAE-encode images and scale them into the UNet's latent space, without autograd."""
    with torch.no_grad():
        return VAE.encode(pixels).latent_dist.sample() * VAE.config.scaling_factor


global_step = 0
# Use the same epoch count the LR scheduler anneals over.
for epoch in range(num_epochs):
    progress_bar = tqdm(total=len(train_dataloader))
    progress_bar.set_description(f"Epoch {epoch}")

    for step, batch in enumerate(train_dataloader):
        images = batch['images'].to(device)
        # Pure-noise "prior" images, encoded the same way as the real images.
        x_pr = torch.randn_like(images)

        # Target-domain and source-domain text conditioning.
        c_tar = encode_prompts(batch['labels_tr'])
        c_sou = encode_prompts(batch['labels_so'])

        z = encode_latents(images)
        z_pr = encode_latents(x_pr)
        noise = torch.randn_like(z)
        bs = z.shape[0]

        # Sample a random timestep for each image (scheduler settings live in .config).
        timesteps = torch.randint(
            0, Scheduler.config.num_train_timesteps, (bs,), device=z.device
        ).long()

        # Forward diffusion: noise the clean latents and the prior latents.
        z_t = Scheduler.add_noise(z, noise, timesteps)
        z_pr_t = Scheduler.add_noise(z_pr, noise, timesteps)

        with torch.no_grad():
            # Frozen reference prediction of the noise residual.
            z_pr_sou = UNetLocked(z_pr_t, timesteps, c_sou).sample

        # Adapted model's noise-residual predictions.
        z_ada = UNetTrained(z_t, timesteps, c_tar).sample
        z_pr_ada = UNetTrained(z_pr_t, timesteps, c_sou).sample
        # The "simple" term is the standard noise-prediction objective, so its target is
        # the noise actually added (previously the clean latent `z` was passed).
        loss = Loss(z_ada, noise, z_pr_sou, z_pr_ada, images)
        loss.backward()
        Optimizer.step()
        Optimizer.zero_grad()

        progress_bar.update(1)
        logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
        progress_bar.set_postfix(**logs)
        global_step += 1

    # SWALR was configured with anneal_epochs=num_epochs, so step it once per epoch.
    lr_scheduler.step()
    progress_bar.close()