In [1]:
import json
import os

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# Data

In [2]:
# Mount Google Drive so the dataset archive under /content/drive/MyDrive can be read below.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
# Remove any previous extraction so the unzip cell starts from a clean directory
# (presumably the archive extracts to /content/domain_set — confirm).
!rm -rf /content/domain_set/

In [4]:
# NOTE(review): this runs *before* the unzip cell, so on a fresh runtime there is nothing
# to delete yet; any __MACOSX/ folder is only created by the extraction below.
# Relative path: resolved against the current working directory.
!rm -rf __MACOSX/

In [5]:
!unzip -q /content/drive/MyDrive/domain_set.zip -d /content/

In [6]:
class MyDataset(Dataset):
    """Image/prompt pairs listed in ``<folder_name>/prompt.json``.

    ``prompt.json`` is read with ``json.load``; every entry must provide a
    ``'target'`` key (image file name, relative to ``folder_name``) and a
    ``'prompt'`` key (caption text).

    Each item is a dict with:
        image:    float32 numpy array in RGB order, scaled from [0, 255] to
                  [-1, 1], in the (H, W, C) layout returned by OpenCV.
        label_tr: the prompt prefixed with ``"A [V] "``.
        label_so: the prompt unchanged.

    Args:
        folder_name: dataset directory. Relative names resolve against the
            current working directory; absolute paths are also accepted.
    """

    def __init__(self, folder_name):
        self.folder = folder_name
        # JSON text is UTF-8 by spec; don't depend on the platform default encoding.
        with open(os.path.join(folder_name, "prompt.json"), 'rt', encoding='utf-8') as f:
            self.data = json.load(f)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]

        image_name = item['target']
        prompt = item['prompt']

        image_path = os.path.join(self.folder, image_name)
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread reports a missing/unreadable file by returning None, which
            # would otherwise surface as an opaque error inside cvtColor.
            raise FileNotFoundError(f"Could not read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = (image.astype(np.float32) / 127.5) - 1.0

        return dict(image=image, label_tr=f"A [V] {prompt}", label_so=f"{prompt}")

In [7]:
# Relative folder name: resolved against the current working directory, presumably
# /content on Colab (where the archive is extracted) — confirm if the cwd changes.
nagai_dataset = MyDataset("domain_set")

In [8]:
# Shuffled mini-batches of {image, label_tr, label_so} dicts.
# NOTE(review): the batch size 4 is repeated as `batch_size` in the training-config
# cell further down; keep the two in sync. The last batch is smaller when
# len(nagai_dataset) is not a multiple of 4.
train_dataloader = DataLoader(nagai_dataset, batch_size=4, shuffle=True)

# Objectives

In [9]:
class Loss_Simple(nn.Module):
  """Mean squared error between a prediction and its regression target.

  Returns the mean over all elements of (y_ada - y_pretrain) ** 2; inputs are
  combined with normal tensor broadcasting.
  """

  def __init__(self):
    super().__init__()

  def forward(self, y_ada, y_pretrain):
    residual = y_ada - y_pretrain
    return residual.mul(residual).mean()

In [10]:
class Loss_PR(nn.Module):
  """Prior-preservation term: mean squared difference between two predictions.

  Returns the mean over all elements of (y_pred_sou - y_pred_ada) ** 2. In this
  notebook it compares UNetLocked's and UNetTrained's outputs on the prior inputs.
  """

  def __init__(self):
    super().__init__()

  def forward(self, y_pred_sou, y_pred_ada):
    delta = y_pred_sou - y_pred_ada
    squared = delta * delta
    return squared.mean()

In [11]:
class PairwiseSimilarityLoss(nn.Module):
    """Pairwise (relational) similarity loss between two batches.

    Both inputs are passed through ``D`` (the ``.sample`` attribute of its result
    is used) and flattened per sample. For each batch, a B x B cosine-similarity
    matrix is built. For every sample i, the similarities to the *other* samples
    j != i are normalised with a softmax. The loss is KL(p_sou || p_ada), averaged
    over the batch (``reduction='batchmean'``).

    Args:
        D: callable returning an object with a ``.sample`` tensor
           (``VAE.decode`` in this notebook).

    Returns:
        Scalar tensor. A batch with fewer than two samples has no pairs and
        yields 0.
    """

    def __init__(self, D):
        super(PairwiseSimilarityLoss, self).__init__()
        self.D = D
        self.sim = F.cosine_similarity
        self.sfm = F.softmax

    @staticmethod
    def _off_diagonal(matrix):
        """Drop the diagonal of a square (B, B) matrix, returning shape (B, B - 1)."""
        b = matrix.size(0)
        keep = ~torch.eye(b, dtype=torch.bool, device=matrix.device)
        # Boolean indexing walks the matrix row-major, so row i keeps its own j != i entries.
        return matrix[keep].view(b, b - 1)

    def forward(self, z_ada, z_sou):
        denoised_ada = self.D(z_ada).sample
        denoised_sou = self.D(z_sou).sample

        batch = denoised_ada.size(0)
        if batch < 2:
            # No (i, j != i) pairs exist, so there is no distribution to compare.
            return denoised_ada.new_zeros(())

        # Flatten to 2-D (B, C*H*W) so each sample is one vector.
        ada_flat = denoised_ada.reshape(batch, -1)
        sou_flat = denoised_sou.reshape(denoised_sou.size(0), -1)

        # Cosine similarity for all pairs, shape (B, B).
        sim_ada = self.sim(ada_flat.unsqueeze(1), ada_flat.unsqueeze(0), dim=2)
        sim_sou = self.sim(sou_flat.unsqueeze(1), sou_flat.unsqueeze(0), dim=2)

        # Self-similarities are removed rather than masked with -inf: a -inf logit
        # gives probability 0, log(0) = -inf, and kl_div then evaluates 0 * -inf = NaN.
        # log_softmax is also more stable than softmax(...).log().
        log_p_ada = F.log_softmax(self._off_diagonal(sim_ada), dim=1)
        p_sou = self.sfm(self._off_diagonal(sim_sou), dim=1)

        # F.kl_div(log q, p) computes KL(p || q): here KL(p_sou || p_ada).
        return F.kl_div(log_p_ada, p_sou, reduction='batchmean')

In [12]:
class HaarWaveletTransform(nn.Module):
    """Single-level Haar detail (high-frequency) extractor for 3-channel images.

    The LH, HL and HH detail filters are 2x2 kernels (scaled by 1/2) applied with
    stride 2. Each filter sums over all three input channels and writes to exactly
    one output channel, so the result packs the subbands of the channel-summed
    image as:
        channel 0 = LH, channel 1 = HL, channel 2 = HH.
    The output spatial size is (H // 2, W // 2).

    Accepts a batched (B, 3, H, W) or an unbatched (3, H, W) tensor and returns a
    tensor with the same batching. The filters are registered as buffers
    ``lh_filter``, ``hl_filter`` and ``hh_filter`` of shape (3, 3, 2, 2)
    (out_channels, in_channels, height, width).
    """

    # (buffer name, 2x2 kernel, output channel the filter writes to)
    _SUBBANDS = (
        ('lh_filter', [[1.0, -1.0], [1.0, -1.0]], 0),
        ('hl_filter', [[1.0, 1.0], [-1.0, -1.0]], 1),
        ('hh_filter', [[-1.0, 1.0], [1.0, -1.0]], 2),
    )

    def __init__(self):
        super().__init__()
        for name, kernel, out_channel in self._SUBBANDS:
            weight = torch.zeros(3, 3, 2, 2)
            # Same kernel for every input channel; all other output channels stay zero.
            weight[out_channel] = torch.tensor(kernel) / 2.0
            self.register_buffer(name, weight)

    def forward(self, x):
        unbatched = x.ndim == 3
        if unbatched:
            # conv2d needs (B, C, H, W): add the *batch* dimension (dim 0), not a channel dim.
            x = x.unsqueeze(0)

        # The three filters occupy disjoint output channels, so summing three separate
        # convolutions equals a single convolution with the summed filter. Matching the
        # input's dtype/device lets fp16 inputs or an un-moved module still work.
        weight = (self.lh_filter + self.hl_filter + self.hh_filter).to(device=x.device, dtype=x.dtype)
        out = F.conv2d(x, weight, stride=2)

        return out.squeeze(0) if unbatched else out



# Example usage:
# haar_wavelet_transform = HaarWaveletTransform()
# hf = haar_wavelet_transform(images)   # images: (B, 3, H, W) -> hf: (B, 3, H // 2, W // 2)
# hf[:, 0], hf[:, 1] and hf[:, 2] hold the LH, HL and HH subbands respectively;
# an unbatched (3, H, W) image returns a (3, H // 2, W // 2) tensor.

In [13]:
class Loss_HF(nn.Module):
  """High-frequency pairwise-similarity loss.

  Computed on the Haar detail components ``hwt(D(z).sample)`` rather than on
  the decoded tensors themselves. For every sample i, the cosine similarities
  to the other samples j != i in the batch are softmax-normalised. The loss is
  KL(p_sou || p_ada), averaged over the batch (``reduction='batchmean'``).

  Args:
    D: callable returning an object with a ``.sample`` tensor
       (``VAE.decode`` in this notebook).

  Returns:
    Scalar tensor. A batch with fewer than two samples has no pairs and yields 0.
  """

  def __init__(self, D):
    super(Loss_HF, self).__init__()
    self.hwt = HaarWaveletTransform()
    self.sim = F.cosine_similarity
    self.sfm = F.softmax
    self.D = D

  @staticmethod
  def _off_diagonal(matrix):
    """Drop the diagonal of a square (B, B) matrix, returning shape (B, B - 1)."""
    b = matrix.size(0)
    keep = ~torch.eye(b, dtype=torch.bool, device=matrix.device)
    # Boolean indexing walks the matrix row-major, so row i keeps its own j != i entries.
    return matrix[keep].view(b, b - 1)

  def forward(self, z_ada, z_sou):
    # Decode, then keep only the high-frequency (fine-grained) details.
    hf_ada = self.hwt(self.D(z_ada).sample)
    hf_sou = self.hwt(self.D(z_sou).sample)

    batch = hf_ada.size(0)
    if batch < 2:
      # No (i, j != i) pairs exist, so there is no distribution to compare.
      return hf_ada.new_zeros(())

    # Flatten to 2-D (B, features) so each sample is one vector.
    ada_flat = hf_ada.reshape(batch, -1)
    sou_flat = hf_sou.reshape(hf_sou.size(0), -1)

    # Cosine similarity for all pairs, shape (B, B).
    sim_ada = self.sim(ada_flat.unsqueeze(1), ada_flat.unsqueeze(0), dim=2)
    sim_sou = self.sim(sou_flat.unsqueeze(1), sou_flat.unsqueeze(0), dim=2)

    # Self-similarities are removed rather than masked with -inf: a -inf logit
    # gives probability 0, log(0) = -inf, and kl_div then evaluates 0 * -inf = NaN.
    log_p_ada = F.log_softmax(self._off_diagonal(sim_ada), dim=1)
    p_sou = self.sfm(self._off_diagonal(sim_sou), dim=1)

    # F.kl_div(log q, p) computes KL(p || q): here KL(p_sou || p_ada).
    return F.kl_div(log_p_ada, p_sou, reduction='batchmean')

In [14]:
class Loss_HFMSE(nn.Module):
  """MSE between the high-frequency components of a decoded prediction and a reference.

  Returns mean((hwt(D(z_ada).sample) - hwt(x_init)) ** 2), where hwt is a
  HaarWaveletTransform and D is a callable returning an object with a
  ``.sample`` tensor (``VAE.decode`` in this notebook).
  """

  def __init__(self, D):
    super().__init__()
    self.hwt = HaarWaveletTransform()
    self.D = D

  def forward(self, z_ada, x_init):
    decoded = self.D(z_ada).sample
    hf_pred = self.hwt(decoded)
    hf_ref = self.hwt(x_init)
    gap = hf_pred - hf_ref
    return (gap * gap).mean()

In [15]:
class DomainLoss(nn.Module):
  """Weighted sum of the domain-adaptation objectives.

  total = L_simple(z_ada, z)
        + l1 * L_pr(z_pr_sou, z_pr_ada)
        + l2 * L_img(z_ada, z_pr_ada)
        + l3 * L_hf(z_ada, z_pr_ada)
        + l4 * L_hfmse(z_ada, x_init)

  Args:
    D: decoder callable forwarded to the image-space losses (``VAE.decode`` here).
    l1, l2, l3, l4: loss weights; the defaults are the previously hard-coded values.
  """

  def __init__(self, D, l1=1, l2=2.5e+2, l3=2.5e+2, l4=0.6):
    super(DomainLoss, self).__init__()
    # Loss components
    self.l_simp = Loss_Simple()
    self.l_pr = Loss_PR()
    self.l_img = PairwiseSimilarityLoss(D)
    self.l_hf = Loss_HF(D)
    self.l_hfmse = Loss_HFMSE(D)

    # Loss weights
    self.l1 = l1
    self.l2 = l2
    self.l3 = l3
    self.l4 = l4

  def forward(self, z_ada, z, z_pr_sou, z_pr_ada, x_init):
    loss_simple = self.l_simp(z_ada, z)
    loss_pr = self.l_pr(z_pr_sou, z_pr_ada)
    # NOTE(review): the pairwise losses take (z_ada, z_sou), but both arguments passed
    # here come from the adapted model; confirm the intended pairing.
    loss_img = self.l_img(z_ada, z_pr_ada)
    loss_hf = self.l_hf(z_ada, z_pr_ada)
    loss_hfmse = self.l_hfmse(z_ada, x_init)
    # Every term is added with its own weight; previously L_hf was *multiplied* by
    # L_hfmse and l4 was never used.
    return (loss_simple
            + self.l1 * loss_pr
            + self.l2 * loss_img
            + self.l3 * loss_hf
            + self.l4 * loss_hfmse)

# Model

### Model Imports

In [16]:
!pip install --upgrade diffusers[torch]



### Component Instantiation

In [17]:
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, UNet2DConditionModel, LMSDiscreteScheduler

In [18]:
# Hugging Face Hub id of the Stable Diffusion v1.5 checkpoint; the VAE and both UNets
# are loaded from its "vae" and "unet" subfolders below.
model_id = "runwayml/stable-diffusion-v1-5"

In [19]:
# CLIP ViT-L/14 tokenizer and text encoder used to embed the prompts.
# The encoder is frozen: it is not in the optimizer (only UNetTrained is), so without
# this, backward() would compute gradients for its weights and keep accumulating them,
# because they are never zeroed.
Tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
Encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
Encoder.requires_grad_(False);

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [20]:
# Frozen VAE. Gradients still flow *through* VAE.decode to its inputs (the image-space
# losses rely on that), but its own weights are not optimized, so no grads are kept for them.
VAE = AutoencoderKL.from_pretrained(model_id, subfolder="vae")
VAE.requires_grad_(False);

In [21]:
# Two independent copies of the pretrained SD UNet: UNetLocked is used only as a
# reference (under torch.no_grad in the training loop); UNetTrained is fine-tuned.
UNetLocked, UNetTrained = [
    UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
    for _ in range(2)
]

In [22]:
from diffusers import DDPMScheduler

# Noise schedule used to corrupt latents during training (same betas as before).
# DDPMScheduler.add_noise returns sqrt(alpha_bar_t) * z_0 + sqrt(1 - alpha_bar_t) * eps,
# the input distribution the pretrained SD UNet expects. LMSDiscreteScheduler.add_noise
# returns z_0 + sigma_t * eps instead. That form is only valid after scale_model_input,
# and the training loop never calls it.
Scheduler = DDPMScheduler(
    beta_start = 0.00085,
    beta_end = 0.012,
    beta_schedule = 'scaled_linear',
    num_train_timesteps = 1000,
    clip_sample = False,
);

In [23]:
# Combined training objective; its image-space terms call VAE.decode on their inputs.
Loss = DomainLoss(VAE.decode)

In [24]:
# NOTE(review): train_dataloader hard-codes batch_size=4 as well; keep the two in sync.
batch_size = 4
# Used as the SWALR annealing length below. The training loop itself iterates
# 1250 // batch_size epochs, not num_epochs — confirm which is intended.
num_epochs = 1250 // (batch_size * 2)

In [25]:
# Only the fine-tuned UNet's parameters are optimized; UNetLocked is never passed to an optimizer.
Optimizer = optim.Adam(UNetTrained.parameters(), lr=4e-6)

In [26]:
# Linearly anneal the learning rate towards swa_lr=1.5e-6 over `anneal_epochs` calls to step().
# NOTE(review): the training loop calls lr_scheduler.step() once per *batch*, so the
# annealing spans num_epochs batches rather than num_epochs epochs — confirm intent.
lr_scheduler = torch.optim.swa_utils.SWALR(Optimizer, anneal_strategy="linear", anneal_epochs=num_epochs, swa_lr=1.5e-6)

# Training

In [27]:
from accelerate import Accelerator
from tqdm.auto import tqdm
import os

## Training Utils

In [28]:
# from dataclasses import dataclass

#
# class TrainingConfig:
#     image_size = 128  # the generated image resolution
#     train_batch_size = 16
#     eval_batch_size = 16  # how many images to sample during evaluation
#     num_epochs = 50
#     gradient_accumulation_steps = 1
#     learning_rate = 1e-4
#     lr_warmup_steps = 500
#     save_image_epochs = 10
#     save_model_epochs = 30
#     mixed_precision = 'fp16'  # `no` for float32, `fp16` for automatic mixed precision
#     output_dir = 'ddpm-butterflies-128'  # the model namy locally and on the HF Hub

#     push_to_hub = False  # whether to upload the saved model to the HF Hub
#     hub_private_repo = False
#     overwrite_output_dir = True  # overwrite the old model when re-running the notebook
#     seed = 0

# config = TrainingConfig()

## Training Loop

### Accelerate-based training loop (disabled; kept for later)

In [29]:
# def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler):
#     # Initialize accelerator and tensorboard logging
#     accelerator = Accelerator(
#         mixed_precision=config.mixed_precision,
#         gradient_accumulation_steps=config.gradient_accumulation_steps,
#         log_with="tensorboard",
#         logging_dir=os.path.join(config.output_dir, "logs")
#     )
#     if accelerator.is_main_process:
#         accelerator.init_trackers("train_example")

#     # Prepare everything
#     # There is no specific order to remember, you just need to unpack the
#     # objects in the same order you gave them to the prepare method.
#     model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
#         model, optimizer, train_dataloader, lr_scheduler
#     )

#     global_step = 0

#     # Now you train the model
#     for epoch in range(config.num_epochs):
#         progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
#         progress_bar.set_description(f"Epoch {epoch}")

#         for step, batch in enumerate(train_dataloader):
#             images = batch['images']
#             x_pr = torch.randn(images.shape).to(images.device)

#             #create c_tar using clip
#             labels_tr = batch['labels_tr']
#             tokens_tr = Tokenizer(labels_tr, padding=True, return_tensors="pt")
#             c_tar = Encoder(**tokens_tr).last_hidden_state

#             #create c_sou using clip
#             labels_so = batch['labels_so']
#             tokens_so = Tokenizer(labels_so, padding=True, return_tensors="pt")
#             c_sou = Encoder(**tokens_so).last_hidden_state

#             # Sample noise to add to the images
#             z = VAE.encode(images)
#             z_pr = VAE.encode(x_pr)
#             noise = torch.randn(z.shape).to(z.device)
#             bs = z.shape[0]

#             # Sample a random timestep for each image
#             timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bs,), device=z.device).long()

#             # Add noise to the clean images according to the noise magnitude at each timestep
#             # (this is the forward diffusion process)
#             z_t = noise_scheduler.add_noise(z, noise, timesteps)

#             # Get the random noise z_pr_t
#             z_pr_t = noise_scheduler.add_noise(z_pr, noise, timesteps)

#             with torch.no_grad():
#                 # Predict the noise residual
#                 z_pr_sou = UNetLocked(z_pr_t, timesteps, c_sou)

#             with accelerator.accumulate(model):
#                 # Predict the noise residual
#                 z_ada = model(z_t, timesteps, c_tar)["sample"]
#                 z_pr_ada = model(z_pr_t, timesteps, c_sou)["sample"]
#                 loss = Loss(z_ada, z, z_pr_sou, z_pr_ada, images)
#                 accelerator.backward(loss)

#                 accelerator.clip_grad_norm_(model.parameters(), 1.0)
#                 optimizer.step()
#                 lr_scheduler.step()
#                 optimizer.zero_grad()

#             progress_bar.update(1)
#             logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
#             progress_bar.set_postfix(**logs)
#             accelerator.log(logs, step=global_step)
#             global_step += 1

### Basic Training Loop

In [30]:
# Fine-tune UNetTrained with DomainLoss. The text encoder, the VAE encoder and UNetLocked
# only provide inputs/targets here, so they run under torch.no_grad().
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Move every model, plus the loss module's Haar filter buffers, to the training device.
for module in (Encoder, VAE, UNetLocked, UNetTrained, Loss):
    module.to(device)
UNetTrained.train()

# SD v1.x UNets operate on VAE latents multiplied by this factor.
latent_scale = VAE.config.scaling_factor
# Read through .config: direct attribute access on the scheduler is deprecated.
num_train_timesteps = Scheduler.config.num_train_timesteps


def encode_prompts(labels):
    """Return CLIP last hidden states for a batch of prompts (frozen encoder, no grad).

    Prompts are padded/truncated to the tokenizer's max length (77 for CLIP), which
    matches how Stable Diffusion conditions its UNet. Over-long prompts are truncated
    instead of raising an error.
    """
    input_ids = Tokenizer(
        list(labels),
        padding="max_length",
        max_length=Tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    ).input_ids.to(device)
    with torch.no_grad():
        return Encoder(input_ids).last_hidden_state


global_step = 0
# NOTE(review): runs 1250 // batch_size epochs, whereas `num_epochs` (the SWALR
# annealing length) is 1250 // (batch_size * 2); confirm which count is intended.
for epoch in range(1250 // batch_size):
    progress_bar = tqdm(total=len(train_dataloader))
    progress_bar.set_description(f"Epoch {epoch}")

    for batch in train_dataloader:
        # The dataset yields (B, H, W, C) images in [-1, 1]; the VAE expects (B, C, H, W).
        images = batch['image'].permute(0, 3, 1, 2).to(device)
        x_pr = torch.randn_like(images)

        # Target-domain ("A [V] ...") and source-domain prompt embeddings.
        c_tar = encode_prompts(batch['label_tr'])
        c_sou = encode_prompts(batch['label_so'])

        with torch.no_grad():
            z = VAE.encode(images).latent_dist.sample() * latent_scale
            z_pr = VAE.encode(x_pr).latent_dist.sample() * latent_scale
        noise = torch.randn_like(z)
        bs = z.shape[0]

        # Sample a random timestep for each image.
        timesteps = torch.randint(0, num_train_timesteps, (bs,), device=device).long()

        # Forward diffusion: corrupt the clean and the prior latents with the same noise.
        z_t = Scheduler.add_noise(z, noise, timesteps)
        z_pr_t = Scheduler.add_noise(z_pr, noise, timesteps)

        with torch.no_grad():
            z_pr_sou = UNetLocked(z_pr_t, timesteps, c_sou).sample

        # Noise-residual predictions of the model being fine-tuned.
        z_ada = UNetTrained(z_t, timesteps, c_tar).sample
        z_pr_ada = UNetTrained(z_pr_t, timesteps, c_sou).sample
        # Epsilon-prediction objective: L_simple regresses the predicted noise onto the
        # noise that was actually added (previously it was compared with the clean latent z).
        loss = Loss(z_ada, noise, z_pr_sou, z_pr_ada, images)
        loss.backward()
        # Same clipping threshold as the Accelerate boilerplate above.
        torch.nn.utils.clip_grad_norm_(UNetTrained.parameters(), 1.0)
        Optimizer.step()
        lr_scheduler.step()
        Optimizer.zero_grad()

        progress_bar.update(1)
        logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
        progress_bar.set_postfix(**logs)
        global_step += 1

    progress_bar.close()

  0%|          | 0/2 [00:00<?, ?it/s]

  deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)


KeyboardInterrupt: 