In [None]:
# get the contents of the github so we can import the FM model
!git clone https://github.com/atong01/conditional-flow-matching.git
%cd conditional-flow-matching
!pip install -r requirements.txt
!pip install torchdiffeq

In [None]:
# ------------------------------------------------------------------------------
# Imports
# ------------------------------------------------------------------------------

# Standard Imports
import os
import numpy as np
import tqdm
import yaml
from typing import KeysView


# Data Storage imports
from google.colab import drive
drive.mount('/content/drive')

# DL imports
import torch
import torchvision
from torchvision.transforms import GaussianBlur
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from torchmetrics.image import lpip, PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure # lpip, PSNR, SSIM


# ODE imports
from torchdyn.core import NeuralODE
import torchdiffeq

# Plot imports
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from torchvision.transforms import ToPILImage
import PIL.Image
from IPython.display import display
from PIL import Image

# Github imports
from torchcfm.conditional_flow_matching import (
    ConditionalFlowMatcher,
    ExactOptimalTransportConditionalFlowMatcher,
    TargetConditionalFlowMatcher,
    VariancePreservingConditionalFlowMatcher,
)
from torchcfm.models import MLP
from torchcfm.models.unet.unet import UNetModelWrapper

# ------------------------------------------------------------------------------
# Global Constants
# ------------------------------------------------------------------------------
# TODO: split into sub-dictionaries (model / corruption / optimisation / logging)
# or move to a YAML config file for better organisation.

CONFIG = {
    'flow_model': 'otcfm',
    'model_variant': 'normal', # normal or ema
    'output_dir': './results/',
    'models_dir': '/content/drive/MyDrive/KTH/DD2412 Deep Learning Adv/Project/flow_models/',
    'in_channels': 3,
    'sample_size': 32,
    'class_cond': False,    # cant use class conditioning currently as none of the pt files contains label_emb but the UNet contains this variable if class_cond==True
    'num_classes': 10,
    'y_class': 0,
    'N': 20,                # presumably the number of outer D-Flow (L-BFGS step) iterations — TODO confirm
    'corrupt_Type': 'inpainting',
    'ODE_steps': 100,       # number of time points passed to torchdiffeq.odeint (torch.linspace)
    'lr': 1,
    'lr_type': 'constant',  # NOTE(review): not read by the code in this notebook section
    'inpaint_percent': 0.9, # fraction of unpainted pixels sampled for the loss
    'use_checkpointing': True,
    'frozen_model': False,
    'ODE_type': 'dopri5',
    'gamma': 1,             # ExponentialLR decay factor; 1 keeps the learning rate constant
    'blend_param': 0.1,     # weight a in x_0 = sqrt(a) * (...) + sqrt(1 - a) * noise
    'max_iter': 20,         # L-BFGS inner iterations per optim.step (20 is torch's default)
    'FID_batch': 16,
    'FID_num_samples': 16,
    'num_images':5,
}

device = "cuda" if torch.cuda.is_available() else "cpu"

# Debugging scratch lists. NOTE(review): d_flow and evaluate_d_flow build their own
# local lists with these names, so these globals are not filled by them.
x_0_list = []
image_list = []
losses = []


In [None]:
def save_output(path, _dict):
  """Persist an output object to disk with torch.save.

  Args:
    path: destination file path.
    _dict: object to serialise (e.g. the per-image D-Flow results dictionary).
  """
  destination, payload = path, _dict
  torch.save(obj=payload, f=destination)

def load_output(path):
  """Load an object previously written with save_output.

  Tensors are remapped onto the module-level `device`. weights_only=False
  unpickles arbitrary Python objects, so only use this on files you produced.
  """
  restored = torch.load(path, map_location=device, weights_only=False)
  return restored

def save_config(config, save_dir):
  """Write `config` as YAML to `<save_dir>/config.yaml`.

  Args:
    config: mapping to serialise with yaml.dump.
    save_dir: a DIRECTORY (created if missing), not a file path; the file is
      always named config.yaml inside it.
  """
  target = os.path.join(save_dir, 'config.yaml')
  os.makedirs(save_dir, exist_ok=True)
  with open(target, 'w') as handle:
    yaml.dump(config, handle)

def load_model(device):
  """
  Build the CIFAR-10 UNet flow model and load its pretrained weights.

  The checkpoint
  `<CONFIG['models_dir']>/<CONFIG['flow_model']>_cifar10_weights_step_400000.pt`
  contains:
    net_model
    ema_model
    sched
    optim
    step
  Only the weights selected by CONFIG['model_variant'] ('normal' -> net_model,
  'ema' -> ema_model) are needed for inference; the rest is training state.

  Side effect: creates the directory `<CONFIG['output_dir']>/<CONFIG['flow_model']>/`.

  Args:
    device: device (e.g. "cuda" or "cpu") to map the checkpoint and model onto.

  Returns:
    The UNetModelWrapper, in eval mode.

  Raises:
    ValueError: if CONFIG['model_variant'] is neither 'normal' nor 'ema'.
    NotImplementedError: if the checkpoint is missing and CONFIG['flow_model']
      is not one of 'otcfm', 'icfm', 'fm'.
    FileNotFoundError: if the checkpoint of a known model is missing.
  """
  variant_to_key = {'normal': 'net_model', 'ema': 'ema_model'}
  # Fail fast: with an unknown variant no state dict would be loaded and a
  # randomly initialised network would be returned.
  if CONFIG['model_variant'] not in variant_to_key:
    raise ValueError(
      f"Unknown model_variant {CONFIG['model_variant']}, must be one of ['normal', 'ema']"
    )

  # Create directory where all results will be stored
  save_dir = os.path.join(CONFIG['output_dir'], CONFIG['flow_model'])
  os.makedirs(save_dir, exist_ok=True)

  # Load checkpoint
  checkpoint_path = os.path.join(
    CONFIG['models_dir'], f"{CONFIG['flow_model']}_cifar10_weights_step_400000.pt"
  )
  print(f"Attempting to load {CONFIG['flow_model']} model from {checkpoint_path}")
  try:
    checkpoint = torch.load(checkpoint_path, weights_only=True, map_location=device)
  except FileNotFoundError as err:
    # torch.load does not know about model names: an unknown name only shows up
    # as a missing checkpoint file, so distinguish the two cases here.
    if CONFIG['flow_model'] not in ('otcfm', 'icfm', 'fm'):
      raise NotImplementedError(
        f"Unknown model {CONFIG['flow_model']}, must be one of ['otcfm', 'icfm', 'fm']"
      ) from err
    raise FileNotFoundError(
        f"Checkpoint file not found at {checkpoint_path}. Verify the path or the file's existence."
    ) from err

  # Create UNet model and load in its corresponding weights
  net_model = UNetModelWrapper(
    dim=(CONFIG['in_channels'], CONFIG['sample_size'], CONFIG['sample_size']),
    class_cond=CONFIG['class_cond'],
    num_classes=CONFIG['num_classes'],
    num_res_blocks=2,
    num_channels=128,
    channel_mult=[1, 2, 2, 2],
    num_heads=4,
    num_head_channels=64,
    attention_resolutions="16",
    dropout=0.1,
    use_checkpoint=CONFIG['use_checkpointing'],
  ).to(device)
  net_model.load_state_dict(checkpoint[variant_to_key[CONFIG['model_variant']]])

  # Inference only: the UNet is built with dropout=0.1 and nn.Module starts in
  # training mode, where dropout would presumably make every vector-field
  # evaluation random (breaking the adaptive ODE solve and the L-BFGS line
  # search). eval() does not block gradients w.r.t. the ODE input.
  net_model.eval()

  if CONFIG['frozen_model']:
    net_model.requires_grad_(False)

  print(f"Succesfully loaded the {CONFIG['model_variant']} {CONFIG['flow_model']} model onto the {device}.")
  if device == 'cpu':
    print("Consider switching device to GPU, as it will run out of RAM during D-flow otherwise.")
  return net_model

In [None]:
# ------------------------------------------------------------------------------
# Corruptions
# ------------------------------------------------------------------------------
def inpainting_corruption(image):
  """Return a copy of `image` with a fixed 4x4 square (rows/cols 14:18) set to 1.

  Args:
    image: 4-D (B, C, H, W) tensor; every batch element gets the same hole.

  Returns:
    A new tensor; the input is not modified. The painted region must stay in
    sync with get_mask, which excludes the same pixels from the loss.
  """
  painted_image = torch.clone(image)
  painted_image[:, :, 14:18, 14:18] = 1
  return painted_image

def blur_corruption(image):
  """Blur `image` with a 3x3 Gaussian kernel (sigma=1) and return the result."""
  return GaussianBlur(3, sigma=1)(image)

def resolution_corruption(image=None):
  """Placeholder for a (presumably down-sampling / super-resolution) corruption.

  Accepts `image` so corrupt_image can call it like the other corruption
  functions, but raises instead of silently returning None.

  Raises:
    NotImplementedError: always, until an implementation exists.
  """
  raise NotImplementedError("resolution corruption is not implemented yet")

def corrupt_image(image, corruption_type):
  """
  Apply the named corruption to `image` and return the corrupted tensor.

  Args:
    image: image tensor passed through to the corruption function.
    corruption_type: one of 'blur', 'inpainting' or 'resolution'.

  Raises:
    ValueError: for an unknown corruption_type, instead of returning None and
      failing later with a confusing error.
  """
  if corruption_type == "blur":
    return blur_corruption(image)
  if corruption_type == "inpainting":
    return inpainting_corruption(image)
  if corruption_type == "resolution":
    return resolution_corruption(image)
  raise ValueError(
    f"Unknown corruption_type {corruption_type!r}, must be one of ['blur', 'inpainting', 'resolution']"
  )

# ------------------------------------------------------------------------------
# Loss Functions
# ------------------------------------------------------------------------------
def get_mask(x, percentage=0.9):
  """Sample flat pixel indices outside the painted 14:18 x 14:18 square.

  Args:
    x: 4-D tensor; only its spatial size (H, W) is read.
    percentage: fraction of the unpainted pixels to keep.

  Returns:
    numpy array of int(n_unpainted * percentage) distinct indices into the
    flattened H*W grid, drawn with the global np.random state.
  """
  _, _, height, width = x.shape

  keep = torch.ones((height, width))
  keep[14:18, 14:18] = 0  # same square that inpainting_corruption paints

  candidates = torch.nonzero(keep.view(height * width)).squeeze()
  n_samples = int(len(candidates) * percentage)
  return np.random.choice(candidates.numpy(), size=n_samples, replace=False)

def inpainting_loss(x,y,mask):
  """Mean squared error between `x` and `y` at a subset of pixel positions.

  Args:
    x: 4-D (B, C, H, W) tensor (the reconstruction).
    y: tensor with B*C*H*W elements (the target), compared position-wise to x.
    mask: 1-D integer indices into the flattened H*W grid (e.g. from get_mask);
      the same positions are used for every batch element and channel.

  Returns:
    Scalar tensor: mean of (x - y)**2 over batch, channels and selected positions.
  """
  B, C, H, W = x.shape

  # reshape (not view): view raises on non-contiguous inputs such as permuted
  # or sliced tensors, reshape copies only when it has to.
  x_flat = x.reshape(B, C, H * W)
  y_flat = y.reshape(B, C, H * W)

  diff = x_flat[:, :, mask] - y_flat[:, :, mask]
  return torch.mean(diff ** 2)

# ------------------------------------------------------------------------------
# Plots
# ------------------------------------------------------------------------------
def display_sample(sample, i):
  """Show the first image of `sample` in the notebook, preceded by "Image at step {i}".

  Args:
    sample: (B, C, H, W) tensor, expected in [-1, 1]; values outside are clamped.
    i: step number used in the caption.
  """
  # Clamp before the uint8 cast: out-of-range floats would map outside [0, 255]
  # and the cast does not saturate, producing wrapped (speckled) pixels.
  # detach() lets this work on tensors that require grad.
  image_processed = sample.detach().cpu().clamp(-1.0, 1.0).permute(0, 2, 3, 1)
  image_processed = (image_processed + 1.0) * 127.5
  image_processed = image_processed.numpy().astype(np.uint8)

  image_pil = PIL.Image.fromarray(image_processed[0])
  display(f"Image at step {i}")
  display(image_pil)

def plot_d_flow_process(x_1_list, img_per_row):
  """Plot the final x_1 of every D-Flow iteration as one image grid.

  Args:
    x_1_list: sequence of per-iteration lists of x_1 tensors; only the last
      entry of each is shown, viewed as (3, 32, 32) and clipped to [-1, 1].
    img_per_row: number of images per grid row.
  """
  # Map each iteration's last image from [-1, 1] to [0, 1].
  frames = [per_iter[-1].view([3, 32, 32]).clip(-1, 1) / 2 + 0.5 for per_iter in x_1_list]

  grid = make_grid(frames, value_range=(0, 1), nrow=img_per_row, padding=2)
  plt.imshow(ToPILImage()(grid))
  plt.axis("off")
  plt.show()

def plot_losses(losses):
  """Plot every recorded loss value, in order, as a single curve.

  Args:
    losses: list of per-iteration lists of loss values. Entries may be Python
      floats (d_flow stores loss.item()) or 0-d tensors; float() handles both,
      whereas calling .item() on a float raises AttributeError.
  """
  all_losses = [float(loss) for sublist in losses for loss in sublist]

  fig, ax = plt.subplots()
  ax.plot(all_losses)
  ax.set_xlabel("Iteration")
  ax.set_ylabel("Loss")
  ax.set_title("Loss vs Iteration")
  plt.show()

def denormalize_images(images):
  """Map images from [-1, 1] to [0, 1] via x / 2 + 0.5.

  Accepts a single tensor or a list of tensors. For a list whose first element
  is 4-D, every element's leading dimension is squeezed first so the result can
  be fed straight to make_grid.
  """
  if not isinstance(images, list):
    return images / 2 + 0.5

  squeeze_first = images[0].ndim == 4
  rescaled = []
  for img in images:
    if squeeze_first:
      img = img.squeeze(0)
    rescaled.append(img / 2 + 0.5)
  return rescaled

def create_img_grid(images, _nrow=4, _padding=2, plot=False):
  """Arrange images (in [-1, 1]) into one grid tensor, optionally drawing it.

  Args:
    images: tensor or list of tensors; rescaled with denormalize_images first.
    _nrow: images per row.
    _padding: padding in pixels between images.
    plot: when True, also show the grid with matplotlib.

  Returns:
    The grid tensor produced by torchvision's make_grid.
  """
  rescaled = denormalize_images(images)
  grid = make_grid(rescaled, nrow=_nrow, padding=_padding, value_range=(0, 1))
  if not plot:
    return grid

  plt.imshow(ToPILImage()(grid))
  plt.axis("off")
  plt.show()
  return grid

In [None]:
def generate_x_0(model, corrupted_y):
  """Build the initial noise x_0 for D-Flow from the corrupted observation.

  Solves the flow ODE backwards (t: 1 -> 0) from `corrupted_y`, with `model`
  as the vector field, and blends the result with fresh Gaussian noise:

      x_0 = sqrt(a) * (corrupted_y + x_0_backward) + sqrt(1 - a) * eps,   a = CONFIG['blend_param']

  Args:
    model: vector field, called as model(t, x).
    corrupted_y: observation tensor, used as the ODE state at t=1.

  Returns:
    Leaf tensor shaped like corrupted_y with requires_grad=True, ready to be
    optimised.
  """
  def ODE_func(t, x):
    return model(t, x)

  # The initialisation needs no gradients.
  with torch.no_grad():
    x_0_backward = torchdiffeq.odeint(
      func=ODE_func,
      y0=corrupted_y,
      t=torch.linspace(1, 0, CONFIG['ODE_steps'], device=device),
      atol=1e-4,
      rtol=1e-4,
      method=CONFIG['ODE_type'],
    )[-1]  # only keep the final value (t=0) from the ODE solver

  blend = CONFIG['blend_param']
  # NOTE(review): the usual D-Flow initialisation blends only the backward-solved
  # noise, sqrt(a) * x_0_backward + sqrt(1 - a) * eps. Adding corrupted_y here mixes
  # data-space and noise-space values — confirm this is intentional.
  x_0 = blend ** 0.5 * (corrupted_y + x_0_backward) + (1. - blend) ** 0.5 * torch.randn_like(x_0_backward)

  # detach() guarantees a fresh leaf the optimiser can own, even if an input ever requires grad.
  return x_0.detach().requires_grad_(True)

# ------------------------------------------------------------------------------
# Our d-flow algorithm implementation
# ------------------------------------------------------------------------------
from torch.utils.tensorboard import SummaryWriter

def setup_experiment_dirs(base_dir, experiment_id):
  """Create `<base_dir>/experiment_<id>/` and its `logs/` subfolder (idempotent).

  Returns:
    (log_dir, output_dir) paths.
  """
  output_dir = os.path.join(base_dir, f'experiment_{experiment_id}')
  log_dir = os.path.join(base_dir, f'experiment_{experiment_id}/logs')
  for folder in (log_dir, output_dir):
    os.makedirs(folder, exist_ok=True)
  return log_dir, output_dir

def save_PIL_images(path, output_dict):
  """Save the final D-Flow reconstruction of every image as a JPEG file.

  For each key of `output_dict`, takes output_dict[key]['x_1_list'][-1][-1]
  (the last x_1 of the last iteration), clamps it to [-1, 1], maps it to
  [0, 1] and writes `<path>/d_flow_image_<idx>.jpg`. Here idx is the key's
  position in iteration order.
  """
  os.makedirs(path, exist_ok=True)

  # Clamp first: the ODE output can leave [-1, 1], and to_pil_image scales float
  # tensors by 255 and casts to uint8 without saturating, wrapping extreme pixels.
  D_flow_images = [output_dict[key]['x_1_list'][-1][-1].clip(-1, 1) for key in output_dict]
  denormalized_D_flow_images = denormalize_images(D_flow_images)

  for idx, image_tensor in enumerate(denormalized_D_flow_images):
    pil_image = TF.to_pil_image(image_tensor.squeeze(0))
    pil_image.save(os.path.join(path, f'd_flow_image_{idx}.jpg'))

def evaluate_d_flow(model, N, corruption_type='inpainting',
                    base_dir='/content/drive/MyDrive/KTH/DD2412 Deep Learning Adv/Project/Runs',
                    images=None):
  """
  Run D-Flow on the first CONFIG['num_images'] ground-truth images and store the results.

  Creates `<base_dir>/experiment_<id>/` containing:
    - config.yaml (a copy of CONFIG),
    - logs/ (TensorBoard events),
    - output_dict.pt,
    - d_flow_images/ (the final reconstructions).

  Args:
    model: flow-model vector field passed to d_flow.
    N: number of outer D-Flow iterations per image.
    corruption_type: corruption applied to each ground-truth image.
    base_dir: directory holding all experiment folders.
    images: indexable collection of ground-truth images; item i is viewed as
      (1, 3, 32, 32). Defaults to the notebook-global `loaded_single_class_images`,
      which is not defined in this section and must come from an earlier cell.

  Returns:
    The output_dict that is also saved to output_dict.pt.
  """
  if images is None:
    images = loaded_single_class_images

  # Setup experiment directories
  os.makedirs(base_dir, exist_ok=True)

  # NOTE(review): the id is derived from the number of entries in base_dir (+2), so
  # deleting a run can make a new run reuse (and overwrite) an existing folder.
  experiment_id = len(os.listdir(base_dir)) + 2
  print(f"Experiment {experiment_id}")
  log_dir, output_dir = setup_experiment_dirs(base_dir, experiment_id)

  # save_config expects a directory and writes config.yaml inside it.
  save_config(CONFIG, output_dir)

  # TensorBoard logging
  writer = SummaryWriter(log_dir)
  output_dict = {}
  try:
    for i in range(CONFIG['num_images']):
      output_dict[i] = {}
      y = images[i].view(1, 3, 32, 32).to(device)

      # Store and log ground truth Y
      test_y = y.clone().view([3, 32, 32]).clip(-1, 1)
      output_dict[i]['y'] = test_y
      writer.add_image('Ground Truth', denormalize_images(test_y), i)

      # Run D-flow (it also fills output_dict[i])
      _, image_list, _ = d_flow(model, y, N, writer, output_dict, corruption_type, img_num=i)

      # Log the last x_1 of every iteration, followed by the ground truth
      grid = create_img_grid(
        [step_images[-1].view([3, 32, 32]).clip(-1, 1) for step_images in image_list + [y]],
        _nrow=3, _padding=2,
      )
      writer.add_image('D-Flow Step Images', grid, i)
  finally:
    # Flush pending TensorBoard events even if D-Flow crashes mid-run.
    writer.close()

  save_output(os.path.join(output_dir, 'output_dict.pt'), output_dict)
  save_PIL_images(os.path.join(output_dir, 'd_flow_images'), output_dict)
  return output_dict

def d_flow(model, y, N, writer, output_dict, corruption_type='inpainting', img_num=0):
  """
  Reconstruct `y` from its corruption by optimising the initial noise x_0 (D-Flow).

  x_0 is pushed through the flow ODE (t: 0 -> 1) and optimised with L-BFGS. The
  objective is that the corrupted reconstruction matches the corrupted observation
  on a fixed random subset of unpainted pixels.

  Args:
    model: vector field, called as model(t, x).
    y: ground-truth image tensor ((1, 3, 32, 32) in the callers shown).
    N: number of outer L-BFGS steps; each runs up to CONFIG['max_iter'] closures.
    writer: TensorBoard SummaryWriter used for image logging.
    output_dict: dict updated in place under key img_num with 'Corrupted Y',
      'x_0_list', 'x_1_list' and 'losses'.
    corruption_type: corruption name passed to corrupt_image.
    img_num: image index, used as output_dict key and logging step.

  Returns:
    (x_0_list, image_list, losses): one list per outer step, each holding one
    entry per closure evaluation (x_0 snapshot, x_1, float loss).
  """
  # Create the corrupted observation; it is a constant target.
  corrupted_y = corrupt_image(y, corruption_type).to(device)
  corrupted_y.requires_grad = False

  # Store and log corrupted Y
  test_corrupted_y = corrupted_y.clone().view([3, 32, 32]).clip(-1, 1)
  output_dict[img_num]["Corrupted Y"] = test_corrupted_y
  writer.add_image('Corrupted Y', denormalize_images(test_corrupted_y), img_num)

  def ODE_func(t, x):
    return model(t, x)

  x_0 = generate_x_0(model, corrupted_y)

  optim = torch.optim.LBFGS(
    [x_0],
    lr=CONFIG['lr'],
    # 20 is torch.optim.LBFGS's own default, used when CONFIG has no 'max_iter'.
    max_iter=CONFIG.get('max_iter', 20),
    history_size=20,
    tolerance_grad=1e-7,
    tolerance_change=1e-9,
    line_search_fn="strong_wolfe"
  )
  scheduler = torch.optim.lr_scheduler.ExponentialLR(optim, gamma=CONFIG['gamma'])

  # Fixed random subset of unpainted pixel positions, reused for every evaluation
  mask = get_mask(corrupted_y, percentage=CONFIG['inpaint_percent'])

  x_0_list, image_list, losses = [], [], []

  # Run D-flow algorithm
  for i in tqdm.tqdm(range(N)):
    iter_losses, iter_x_0s, iter_x_1s = [], [], []

    def closure():
      optim.zero_grad()

      # Solve ODE forward from the current noise
      x_1 = torchdiffeq.odeint(
        func=ODE_func,
        y0=x_0,
        t=torch.linspace(0, 1, CONFIG['ODE_steps'], device=device),
        atol=1e-4,
        rtol=1e-4,
        method=CONFIG['ODE_type'],
      )[-1]  # only keep the final value from the ODE solver

      # Apply the same corruption operator to the reconstruction before comparing
      # it with the corrupted observation; otherwise, e.g. for blur, x_1 would be
      # fitted to the *blurred* image. For inpainting this only changes the painted
      # square, which get_mask already excludes, so the loss is unchanged.
      loss = inpainting_loss(corrupt_image(x_1, corruption_type), corrupted_y, mask)

      # Record for debugging and plots. clone(): L-BFGS updates x_0 in place, so a
      # bare detach() would alias every stored snapshot to the final x_0.
      iter_losses.append(loss.item())
      iter_x_0s.append(x_0.detach().clone())
      iter_x_1s.append(x_1.detach())

      loss.backward()

      torch.nn.utils.clip_grad_norm_([x_0], max_norm=1)
      return loss

    optim.step(closure)
    scheduler.step()

    losses.append(iter_losses)
    x_0_list.append(iter_x_0s)
    image_list.append(iter_x_1s)

    # Log every closure's x_1 of this outer step as one grid
    grid = create_img_grid([image.view([3, 32, 32]).clip(-1, 1) for image in iter_x_1s])
    writer.add_image(f'Closure Images/img_{img_num}', grid, i)

  output_dict[img_num]['x_0_list'] = x_0_list
  output_dict[img_num]['x_1_list'] = image_list
  output_dict[img_num]['losses'] = losses
  return x_0_list, image_list, losses


In [None]:
%load_ext tensorboard
%tensorboard --logdir "/content/drive/MyDrive/KTH/DD2412 Deep Learning Adv/Project/Runs"

In [None]:
# Change CONFIG as needed, then run.
# The iteration count comes from CONFIG['N'], so the config.yaml saved with the run
# records the value that was actually used.
net_model = load_model(device)

evaluate_d_flow(net_model, CONFIG['N'], corruption_type=CONFIG['corrupt_Type'])

In [None]:
def load_output(path, device):
  """Load a torch-saved object from `path`, mapping its tensors onto `device`.

  NOTE: this redefinition shadows the one-argument load_output defined earlier
  in the notebook; after this cell runs, callers must pass `device` explicitly.
  weights_only=False unpickles arbitrary Python objects — only load trusted files.
  """
  restored = torch.load(path, map_location=device, weights_only=False)
  return restored

def eval_model(true_data_dir, d_flow_output_path, N_patches=8):
  """
  Compute LPIPS, PSNR and SSIM between final D-Flow reconstructions and ground truth.

  Args:
    true_data_dir: directory containing single_class_images.pt, a mapping whose
      values are the ground-truth images (each is unsqueezed to a batch of one).
    d_flow_output_path: path to an output_dict.pt written by evaluate_d_flow;
      output_dict[k]['x_1_list'][-1][-1] is taken as the final reconstruction.
    N_patches: unused; kept for backward compatibility.

  Returns:
    dict with the raw metric states (kept for compatibility) and the averaged
    scores. All images are assumed to be in [-1, 1] (data_range = 2). Pairs are
    matched in iteration order; zip stops at the shorter collection.

  Raises:
    ValueError: if there are no image pairs to evaluate.
  """
  single_class_images_path = os.path.join(true_data_dir, "single_class_images.pt")

  output_dict = load_output(d_flow_output_path, "cpu")
  single_class_images = load_output(single_class_images_path, "cpu")

  final_d_flow_images = [output_dict[key]['x_1_list'][-1][-1].clip(-1, 1) for key in output_dict.keys()]
  ground_truth_images = [img.unsqueeze(0) for img in single_class_images.values()]

  # The metrics live on `device`; the CPU-loaded images are moved there below.
  lpip_metric = lpip.LearnedPerceptualImagePatchSimilarity(net_type='squeeze').to(device)
  psnr_metric = PeakSignalNoiseRatio(data_range=2.0).to(device)  # Data range for [-1, 1]
  ssim_metric = StructuralSimilarityIndexMeasure(data_range=2.0).to(device)  # Data range for [-1, 1]

  n_pairs = 0
  for pred, target in zip(final_d_flow_images, ground_truth_images):
    pred, target = pred.to(device), target.to(device)
    lpip_metric.update(pred, target)
    psnr_metric.update(pred, target)
    ssim_metric.update(pred, target)
    n_pairs += 1

  if n_pairs == 0:
    raise ValueError("No (reconstruction, ground truth) image pairs to evaluate.")

  eval_dict = {
    "LPIPS_sum_total": lpip_metric.sum_scores,
    "LPIPS_total": lpip_metric.total,
    "LPIPS mean": lpip_metric.compute(),
    "SSIM similarity": ssim_metric.compute(),  # mean SSIM over all pairs
    "SSIM total": ssim_metric.total,
    "PSNR sum_squared_error": psnr_metric.sum_squared_error / n_pairs,  # mean per-image SSE
    "PSNR total": psnr_metric.total,
    # 10 * log10(data_range**2 / MSE), MSE pooled over every pixel of every pair
    "PSNR value": psnr_metric.compute(),
  }

  return eval_dict

# Specify the experiments you want to evaluate
exp_nums = [97, 84, 122, 247, 134, 140, 135, 139]
# Folder holding the experiment_<id>/ directories written by evaluate_d_flow
RUNS_DIR = '/content/drive/MyDrive/KTH/DD2412 Deep Learning Adv/Project/Runs'
PATH = "PATH"            # output_dict.pt of any run; supplies the ground-truth / corrupted rows
DATA_PATH = "DATA_PATH"  # directory containing single_class_images.pt

final_images = []

# Get GT and corrupted GT (one grid row each)
reference_output = load_output(PATH, device)
final_images.extend(reference_output[key]['y'].unsqueeze(0) for key in reference_output)
final_images.extend(reference_output[key]['Corrupted Y'].unsqueeze(0) for key in reference_output)

# Evaluate each experiment and add its final d-flow images (one grid row per experiment)
for exp in exp_nums:
  exp_output_path = os.path.join(RUNS_DIR, f"experiment_{exp}", "output_dict.pt")
  print("\n----------------------------------")
  print(f"Experiment {exp}")
  eval_dict = eval_model(DATA_PATH, exp_output_path)
  for key, value in eval_dict.items():
    print(f"{key} {float(value):.4f}")
  exp_output = load_output(exp_output_path, device)
  final_images.extend(exp_output[key]['x_1_list'][-1][-1].clip(-1, 1) for key in exp_output)

# Plot images
grid = create_img_grid(final_images, _nrow=5, _padding=1, plot=True)
