In [None]:
!git clone https://github.com/atong01/conditional-flow-matching.git

In [None]:
%cd conditional-flow-matching


In [None]:
!pip install -r requirements.txt


In [None]:
!pip install torchdiffeq
!pip install torch_optimizer

In [None]:
from typing import KeysView
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from torchvision.transforms import ToPILImage
import PIL.Image
from IPython.display import display
import os
import torch_optimizer
from google.colab import drive
drive.mount('/content/drive')

def save_output(path, _dict):
  """Serialize ``_dict`` to the file at ``path`` with ``torch.save``."""
  torch.save(obj=_dict, f=path)

def load_output(path, device=None):
  """Load an object that was written with ``torch.save``.

  Args:
    path: File to read.
    device: Optional ``map_location`` for the stored tensors. When omitted,
      "cuda" is used if CUDA is available and "cpu" otherwise, so the helper
      does not depend on a global ``device`` having been defined beforehand.

  Returns:
    The deserialized object.
  """
  if device is None:
    device = "cuda" if torch.cuda.is_available() else "cpu"
  # weights_only=False unpickles arbitrary Python objects: only load trusted files.
  return torch.load(path, weights_only=False, map_location=device)

def denormalize_images(images):
  """Apply ``x / 2 + 0.5`` to a tensor or to every tensor in a list/tuple.

  Args:
    images: A single tensor, or a list/tuple of tensors. In the sequence case
      every 4-D element has its leading dimension squeezed (``squeeze(0)``)
      before rescaling; the decision is made per element, so mixed 3-D/4-D
      sequences and empty sequences are handled.

  Returns:
    A tensor when a tensor was given, otherwise a list of tensors.
  """
  if isinstance(images, (list, tuple)):
    rescaled = []
    for img in images:
      if img.ndim == 4:
        img = img.squeeze(0)
      rescaled.append(img / 2 + 0.5)
    return rescaled
  return images / 2 + 0.5

def create_img_grid(images, _nrow=4, _padding=2, plot=False):
  """Tile ``images`` into a single grid tensor and optionally show it.

  The images are first passed through ``denormalize_images`` and then laid out
  with ``make_grid`` using ``_nrow`` images per row and ``_padding`` pixels of
  padding. When ``plot`` is true the grid is also rendered with matplotlib.

  Returns:
    The grid tensor produced by ``make_grid``.
  """
  rescaled = denormalize_images(images)
  grid = make_grid(rescaled, nrow=_nrow, padding=_padding, value_range=(0,1))
  if not plot:
    return grid

  plt.imshow(ToPILImage()(grid))
  plt.axis("off")
  plt.show()
  return grid


# Google Drive folder that receives the prepared CIFAR-10 tensors.
save_dir = '/content/drive/MyDrive/Mac pro dator googl drive/Universitet/KTH/DD2412/Project/CIFAR10_data'

# ToTensor followed by a per-channel Normalize with mean 0.5 and std 0.5.
cifar_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

# CIFAR-10 test split, downloaded into ./data when missing.
testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=cifar_transform,
)

# CIFAR-10 label ids to keep; images are stored under the label's position in this list.
classes = [0, 2, 3, 5, 9]
class_images = {slot: [] for slot, _ in enumerate(classes)}

# Bucket every test image whose label is one of the selected classes.
for img, label in testset:
  if label not in classes:
    continue
  class_images[classes.index(label)].append(img)

# One representative per class: the second image collected for it.
single_class_images = {slot: imgs[1] for slot, imgs in class_images.items()}

create_img_grid(list(single_class_images.values()), _nrow=1, _padding=2, plot=True)

# First ten images of every class, one class per grid row.
subset_all_images = sum((imgs[:10] for imgs in class_images.values()), [])
create_img_grid(subset_all_images, _nrow=10, _padding=2, plot=True)


# Persist both dictionaries under save_dir.
os.makedirs(save_dir, exist_ok=True)
class_images_path, single_class_images_path = (
    os.path.join(save_dir, file_name)
    for file_name in ("class_images.pt", "single_class_images.pt")
)

# Print the keys under which the single-class images are stored.
for key in single_class_images:
  print(key)

save_output(class_images_path, class_images)
save_output(single_class_images_path, single_class_images)



In [None]:
import os
import numpy as np
import tqdm
import yaml
from google.colab import drive
drive.mount('/content/drive')
import torch
import torchvision
from torchvision.transforms import GaussianBlur
from torchdyn.core import NeuralODE
import torchdiffeq
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from torchvision.transforms import ToPILImage
import PIL.Image
from IPython.display import display
from torchcfm.conditional_flow_matching import (
    ConditionalFlowMatcher,
    ExactOptimalTransportConditionalFlowMatcher,
    TargetConditionalFlowMatcher,
    VariancePreservingConditionalFlowMatcher,
)
from torchcfm.models import MLP
from torchcfm.models.unet.unet import UNetModelWrapper

# Experiment settings shared by the cells below; evaluate_d_flow also dumps this dict to YAML.
CONFIG = {
    'flow_model': 'otcfm',       # prefix of the checkpoint file name used by load_model
    'model_variant': 'normal',   # 'normal' -> checkpoint['net_model'], 'ema' -> checkpoint['ema_model']
    'output_dir': './results/',
    'models_dir': '/content/drive/MyDrive/Mac pro dator googl drive/Universitet/KTH/DD2412/Project/flow_models',
    'in_channels': 3,
    'sample_size': 32,
    'class_cond': False,
    'num_classes': 10,
    'y_class': 0,
    'N': 20,
    # Must be one of the names handled by corrupt_image: 'mask_area', 'mask_center',
    # 'mask_half', 'mask_gaussian', 'mask_edge'. The previous value 'inpainting' is not
    # handled there, so the run failed before optimisation started.
    'corrupt_Type': 'mask_center',
    'ODE_steps': 100,            # number of time points passed to torchdiffeq.odeint
    'lr': 1,
    'lr_list': [1,1,1,1,1],
    'lr_type': 'constant',
    'inpaint_percent': 0.9,      # fraction of known pixels to sample for the loss — NOTE(review): confirm d_flow reads this
    'use_checkpointing': True,
    'frozen_model': False,       # when True load_model sets requires_grad=False on all weights
    'ODE_type': 'dopri5',        # torchdiffeq solver name
    'gamma': 1,                  # ExponentialLR decay factor used in d_flow
    'blend_param': 0.1,          # blending weight used by generate_x_0
    'FID_batch': 16,
    'FID_num_samples': 16,
    'num_images':5,              # number of images processed by evaluate_d_flow
    'inpaint_param':3,           # second positional argument forwarded to the mask_* function
    'max_iter':5,                # LBFGS max_iter per optimiser step
    'lr_start':1,
    # NOTE(review): hyphenated, unlike 'lr_start'; not read by the code in this notebook.
    'lr-end':0.4,
    'total_iter':30              # number of outer optimisation steps passed to evaluate_d_flow

}

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = "cuda"
else:
  device = "cpu"

# Module-level accumulators, all starting out empty.
x_0_list, image_list, losses = [], [], []


In [None]:
def save_output(path, _dict):
  """Write ``_dict`` to ``path`` using ``torch.save``."""
  torch.save(obj=_dict, f=path)

def load_output(path, device):
  """Read an object stored with ``torch.save`` and map its tensors onto ``device``.

  Uses ``weights_only=False``, i.e. full unpickling, so only trusted files
  should be loaded.
  """
  loaded = torch.load(path, map_location=device, weights_only=False)
  return loaded



In [None]:
def save_config(config, save_dir):
  """Dump ``config`` as YAML to ``<save_dir>/config.yaml``.

  ``save_dir`` is treated as a directory and is created when it does not exist.
  """
  os.makedirs(save_dir, exist_ok=True)
  with open(os.path.join(save_dir, 'config.yaml'), 'w') as handle:
    yaml.dump(config, handle)

def load_model(device):
  """Build the UNet described by ``CONFIG`` and load pretrained weights for inference.

  Checkpoint contains:
    net_model
    ema_model
    sched
    optim
    step
  We only need the model weights for inference, as the rest are for training
  the flow model. ``CONFIG['model_variant']`` selects ``net_model`` ('normal')
  or ``ema_model`` ('ema').

  Args:
    device: Device (string or ``torch.device``) the checkpoint is mapped to and
      the model is moved onto.

  Returns:
    The ``UNetModelWrapper`` with loaded weights, in eval mode.

  Raises:
    ValueError: ``CONFIG['model_variant']`` is neither 'normal' nor 'ema'.
    FileNotFoundError: The checkpoint file does not exist.
    NotImplementedError: Re-raised with a hint if loading raises it.
  """
  # Fail early instead of silently returning a randomly initialised network.
  state_dict_keys = {'normal': 'net_model', 'ema': 'ema_model'}
  if CONFIG['model_variant'] not in state_dict_keys:
    raise ValueError(
      f"Unknown model_variant {CONFIG['model_variant']!r}, must be one of {sorted(state_dict_keys)}"
    )

  # Create directory where all results will be stored
  save_dir = os.path.join(CONFIG['output_dir'], CONFIG['flow_model'])
  os.makedirs(save_dir, exist_ok=True)

  # Load checkpoint
  checkpoint_path = os.path.join(
    CONFIG['models_dir'], f"{CONFIG['flow_model']}_cifar10_weights_step_400000.pt"
  )
  print(f"Attempting to load {CONFIG['flow_model']} model from {checkpoint_path}")
  try:
    checkpoint = torch.load(checkpoint_path, weights_only=True, map_location=device)
  except NotImplementedError as err:
    raise NotImplementedError(
      f"Unknown model {CONFIG['flow_model']}, must be one of ['otcfm', 'icfm', 'fm']"
    ) from err
  except FileNotFoundError as err:
    raise FileNotFoundError(
        f"Checkpoint file not found at {checkpoint_path}. Verify the path or the file's existence."
    ) from err

  # Create UNet model and load in its corresponding weights
  net_model = UNetModelWrapper(
    dim=(CONFIG['in_channels'], CONFIG['sample_size'], CONFIG['sample_size']),
    class_cond=CONFIG['class_cond'],
    num_classes=CONFIG['num_classes'],
    num_res_blocks=2,
    num_channels=128,
    channel_mult=[1, 2, 2, 2],
    num_heads=4,
    num_head_channels=64,
    attention_resolutions="16",
    dropout=0.1,
    use_checkpoint=CONFIG['use_checkpointing'],
  ).to(device)
  net_model.load_state_dict(checkpoint[state_dict_keys[CONFIG['model_variant']]])

  # The network is built with dropout=0.1; eval mode disables it so the learned
  # vector field is deterministic while the ODE is solved and differentiated.
  net_model.eval()

  # Gradient computations in D-flow were observed to fail at some point with frozen
  # weights and later to work again.
  # NOTE(review): possibly an interaction with use_checkpoint=True — confirm.
  if CONFIG['frozen_model']:
    for param in net_model.parameters():
      param.requires_grad = False

  print(f"Succesfully loaded the {CONFIG['model_variant']} {CONFIG['flow_model']} model onto the {device}.")
  # Works for both "cpu" strings and torch.device objects.
  if torch.device(device).type == 'cpu':
    print("Consider switching device to GPU, as it will run out of RAM during D-flow otherwise.")
  return net_model

# test_model = load_model(device)

In [None]:
def subsample(mask, percentage=1):
  """Randomly pick a fraction of the non-zero positions of a 2-D mask.

  Args:
    mask: 2-D tensor (H, W); non-zero entries mark selectable pixels.
    percentage: Fraction of the selectable pixels to draw; the count is
      ``int(num_selectable * percentage)``.

  Returns:
    1-D numpy array of row-major flat indices into the H*W plane, drawn without
    replacement via ``np.random.choice``.
  """
  H, W = mask.shape
  # reshape also accepts non-contiguous masks, which view would reject.
  mask_flat = mask.detach().cpu().reshape(H * W)
  # flatten() keeps the result 1-D even when exactly one pixel is selectable
  # (squeeze() would collapse it to a 0-d tensor, on which len() fails).
  valid_indices = torch.nonzero(mask_flat).flatten().numpy()
  sample_size = int(len(valid_indices) * percentage)
  return np.random.choice(valid_indices, size=sample_size, replace=False)

def mask_area(img, area=(slice(10, 14), slice(10, 14))):
  """Zero out a region of ``img`` selected by ``area``.

  Args:
    img: Tensor whose last two dimensions are (H, W).
    area: Index applied to the (H, W) mask; the selected entries are masked.
      The default masks rows 10-13 and columns 10-13.

  Returns:
    (masked_img, mask): ``img`` multiplied by the mask, and the (H, W) float
    mask with 0 at masked positions and 1 elsewhere. The mask is created on
    ``img``'s device, so no global ``device`` is needed.
  """
  mask = torch.ones(img.shape[-2], img.shape[-1], device=img.device)
  mask[area] = 0
  # Multiplication allocates a new tensor, so img itself is left untouched.
  masked_img = img * mask
  return masked_img, mask

def mask_center(img, width=2):
  """Zero out a centred square of side ``2 * width`` in ``img``.

  Args:
    img: Tensor whose last two dimensions are (H, W).
    width: Half side length of the masked square, in pixels.

  Returns:
    (masked_img, mask): ``img`` multiplied by the mask, and the (H, W) float
    mask (0 inside the square, 1 elsewhere) on ``img``'s device.
  """
  height, length = img.shape[-2], img.shape[-1]
  row_mid, col_mid = height // 2, length // 2
  mask = torch.ones(height, length, device=img.device)
  # Clamp the start at 0: a negative start would wrap around to the far edge
  # when width exceeds half the image size.
  row_start = max(row_mid - width, 0)
  col_start = max(col_mid - width, 0)
  mask[row_start:row_mid + width, col_start:col_mid + width] = 0
  masked_img = img * mask
  return masked_img, mask

def mask_half(img, width=3):
  """Zero out the right-most ``width`` columns of ``img``.

  Args:
    img: Tensor whose last two dimensions are (H, W).
    width: Number of columns to mask, counted from the right edge. ``0`` masks
      nothing.

  Returns:
    (masked_img, mask): ``img`` multiplied by the mask, and the (H, W) float
    mask (0 in the masked columns, 1 elsewhere) on ``img``'s device.
  """
  n_cols = img.shape[-1]
  mask = torch.ones(img.shape[-2], n_cols, device=img.device)
  # Explicit start index: the slice "-width:" would select every column for width == 0.
  mask[:, max(n_cols - width, 0):] = 0
  masked_img = img * mask
  return masked_img, mask

def mask_gaussian(img, mask_percent=0.2):
  """Randomly zero out pixels of ``img``.

  Each pixel is masked independently: a uniform sample from ``np.random.rand``
  is drawn per pixel and the pixel is kept when the sample exceeds
  ``mask_percent`` (despite the function's name, the noise is uniform).

  Args:
    img: Tensor whose last two dimensions are (H, W).
    mask_percent: Threshold in [0, 1]; roughly this fraction of pixels is masked.

  Returns:
    (masked_img, mask): ``img`` multiplied by the mask, and the (H, W) boolean
    mask (True = kept) on ``img``'s device.
  """
  keep = np.random.rand(img.shape[-2], img.shape[-1]) > mask_percent
  mask = torch.tensor(keep).to(img.device)
  masked_img = img * mask
  return masked_img, mask

def mask_edge(img, width=2):
  """Zero out a border of ``width`` pixels on all four sides of ``img``.

  Args:
    img: Tensor whose last two dimensions are (H, W).
    width: Border thickness in pixels. Values <= 0 mask nothing.

  Returns:
    (masked_img, mask): ``img`` multiplied by the mask, and the (H, W) float
    mask (0 on the border, 1 inside) on ``img``'s device.
  """
  mask = torch.ones(img.shape[-2], img.shape[-1], device=img.device)
  # Guard: for width == 0 the slice "-width:" would select the whole axis.
  if width > 0:
    mask[:, :width] = 0
    mask[:, -width:] = 0
    mask[:width, :] = 0
    mask[-width:, :] = 0
  masked_img = img * mask
  return masked_img, mask

def corrupt_image(image, corruption_type):
  """Apply the masking function named by ``corruption_type`` to ``image``.

  ``CONFIG['inpaint_param']`` is forwarded as the second positional argument of
  the selected function.

  Args:
    image: Tensor whose last two dimensions are (H, W).
    corruption_type: One of "mask_area", "mask_center", "mask_half",
      "mask_gaussian", "mask_edge".

  Returns:
    (masked_img, mask) as returned by the selected masking function.

  Raises:
    ValueError: ``corruption_type`` is not one of the supported names
      (previously ``None`` was returned silently, which only failed later when
      the caller unpacked the result).
  """
  corruptions = {
    "mask_area": mask_area,
    "mask_center": mask_center,
    "mask_half": mask_half,
    "mask_gaussian": mask_gaussian,
    "mask_edge": mask_edge,
  }
  if corruption_type not in corruptions:
    raise ValueError(
      f"Unknown corruption type {corruption_type!r}, must be one of {sorted(corruptions)}"
    )
  # NOTE(review): for "mask_area" the parameter is used as the mask index (`area`) and
  # for "mask_gaussian" as the masking threshold, not as a pixel width — confirm that
  # CONFIG['inpaint_param'] holds a suitable value for those two types.
  return corruptions[corruption_type](image, CONFIG['inpaint_param'])



def inpainting_loss(x,y,mask):
  """Mean squared error between ``x`` and ``y`` at selected pixel positions.

  Args:
    x: Tensor of shape (B, C, H, W).
    y: Tensor that can be reshaped to (B, C, H*W).
    mask: Index into the flattened H*W plane, e.g. the integer array returned
      by ``subsample``. The same positions are compared in every batch element
      and channel.

  Returns:
    Scalar tensor with the mean of the squared differences (NaN when ``mask``
    selects no pixel).
  """
  B, C, H, W = x.shape

  # Flatten the spatial dimensions as the mask indexes the flat plane.
  # reshape (unlike view) also works for non-contiguous inputs.
  x_flat, y_flat = x.reshape(B, C, H * W), y.reshape(B, C, H * W)

  # Bring the index onto the tensors' device; integer indices are cast to int64.
  pixel_idx = torch.as_tensor(mask, device=x.device)
  if pixel_idx.dtype != torch.bool:
    pixel_idx = pixel_idx.long()

  # Retrieve the values at the sampled indices
  diff = x_flat[:, :, pixel_idx] - y_flat[:, :, pixel_idx]

  # Return the loss
  return torch.mean(diff ** 2)


def display_sample(sample, i):
  """Show the first image of a batch in the notebook output.

  Args:
    sample: Tensor of shape (B, C, H, W) with values nominally in [-1, 1];
      only ``sample[0]`` is shown.
    i: Step number used in the caption.
  """
  # detach() allows tensors that still require grad to be converted to numpy.
  image_processed = sample.detach().cpu().permute(0, 2, 3, 1)
  # Map [-1, 1] -> [0, 255]; clamp so out-of-range values saturate instead of
  # wrapping around in the uint8 cast.
  image_processed = ((image_processed + 1.0) * 127.5).clamp(0, 255)
  image_processed = image_processed.numpy().astype(np.uint8)

  image_pil = PIL.Image.fromarray(image_processed[0])
  display(f"Image at step {i}")
  display(image_pil)

def plot_d_flow_process(x_1_list, img_per_row):
  """Plot the last image of every entry in ``x_1_list`` as one grid.

  Each entry's final element is reshaped to (3, 32, 32), clipped to [-1, 1]
  and rescaled with ``x / 2 + 0.5`` before being tiled with ``img_per_row``
  images per row.
  """
  final_frames = [
    step_images[-1].view([3, 32, 32]).clip(-1, 1) / 2 + 0.5
    for step_images in x_1_list
  ]

  grid = make_grid(final_frames, nrow=img_per_row, padding=2, value_range=(0, 1))
  plt.imshow(ToPILImage()(grid))
  plt.axis("off")
  plt.show()

def plot_losses(losses):
  """Plot all loss values of a nested loss list as one curve.

  Args:
    losses: List of lists of loss values. Entries may be tensors (anything
      with ``.item()``) or plain numbers, as stored by ``d_flow``.
  """
  # d_flow records losses via loss.item(), i.e. as floats, so .item() cannot be
  # called unconditionally.
  all_losses = [
    loss.item() if hasattr(loss, "item") else float(loss)
    for sublist in losses
    for loss in sublist
  ]

  # Plot the loss curve
  plt.plot(all_losses)
  plt.xlabel("Iteration")
  plt.ylabel("Loss")
  plt.title("Loss vs Iteration")
  plt.show()

In [None]:
def generate_x_0(model, corrupted_y):
  """Create the initial latent ``x_0`` for the D-flow optimisation.

  The model ODE is integrated backwards from t=1 to t=0 starting at
  ``corrupted_y`` (without tracking gradients). The result is blended with
  fresh Gaussian noise using ``CONFIG['blend_param']``.

  Args:
    model: Callable network; evaluated as ``model.forward(t, x)``.
    corrupted_y: Observed (corrupted) image tensor used as the ODE start value.

  Returns:
    A leaf tensor shaped like ``corrupted_y`` with ``requires_grad=True``.
  """
  ODE_func = lambda t, x: model.forward(t, x)

  with torch.no_grad():
    inverted = torchdiffeq.odeint(
      func=ODE_func,
      y0=corrupted_y,
      t=torch.linspace(1, 0, CONFIG['ODE_steps'], device=device),
      atol=1e-4,
      rtol=1e-4,
      method=CONFIG['ODE_type'],
    )[-1].to(device)

  blend = CONFIG['blend_param']
  # NOTE(review): the blended term is (corrupted_y + inverted latent), not the inverted
  # latent alone — confirm this is the intended initialisation.
  x_0 = blend ** 0.5 * (corrupted_y + inverted) + (1. - blend) ** 0.5 * torch.randn_like(inverted)
  # detach() guarantees a leaf tensor, so enabling gradients cannot fail even if
  # corrupted_y happens to carry autograd history.
  return x_0.detach().requires_grad_(True)

from torch.utils.tensorboard import SummaryWriter

def setup_experiment_dirs(base_dir, experiment_id):
  """Create the folders of one experiment and return their paths.

  Returns:
    (log_dir, output_dir): ``<base_dir>/experiment_<id>/logs`` and
    ``<base_dir>/experiment_<id>``; both exist afterwards.
  """
  output_dir = os.path.join(base_dir, f'experiment_{experiment_id}')
  log_dir = os.path.join(base_dir, f'experiment_{experiment_id}/logs')
  for folder in (log_dir, output_dir):
    os.makedirs(folder, exist_ok=True)
  return log_dir, output_dir

from PIL import Image
import torchvision.transforms.functional as TF

def save_PIL_images(path, output_dict):
  """Save the final D-flow image of every entry in ``output_dict`` as a JPEG.

  For each entry the last element of the last list in ``'x_1_list'`` is taken,
  rescaled with ``denormalize_images`` and written to
  ``<path>/d_flow_image_<idx>.jpg``.

  Args:
    path: Target directory; created when missing.
    output_dict: Mapping whose values contain an ``'x_1_list'`` entry (a list
      of lists of image tensors).
  """
  os.makedirs(path, exist_ok=True)
  D_flow_images = [entry['x_1_list'][-1][-1] for entry in output_dict.values()]
  denormalized_D_flow_images = denormalize_images(D_flow_images)

  for idx, image_tensor in enumerate(denormalized_D_flow_images):
    # The stored ODE outputs are not clipped; clamp to [0, 1] so the 8-bit
    # conversion saturates instead of wrapping around.
    pil_image = TF.to_pil_image(image_tensor.squeeze(0).detach().clamp(0, 1).cpu())

    image_file_path = os.path.join(path, f'd_flow_image_{idx}.jpg')
    pil_image.save(image_file_path)

def evaluate_d_flow(model, N, corruption_type='inpainting', images=None):
  """Run ``d_flow`` on ``CONFIG['num_images']`` images and store all results.

  A new ``experiment_<id>`` folder is created below the Drive ``Runs``
  directory; it receives the YAML config, TensorBoard logs, ``output_dict.pt``
  and the final images as JPEGs.

  Args:
    model: Flow model passed on to ``d_flow``.
    N: Number of outer optimisation steps per image.
    corruption_type: Name understood by ``corrupt_image``.
    images: Optional indexable collection of image tensors (each reshapeable
      to (1, 3, 32, 32)). When omitted, the notebook global
      ``loaded_single_class_images`` is used if it exists; otherwise the file
      at the global ``single_class_images_path`` is loaded.

  Raises:
    NameError: No images were passed and neither global is defined.
  """
  base_dir = '/content/drive/MyDrive/Mac pro dator googl drive/Universitet/KTH/DD2412/Project/Runs'
  os.makedirs(base_dir, exist_ok=True)

  if images is None:
    notebook_globals = globals()
    if 'loaded_single_class_images' in notebook_globals:
      images = notebook_globals['loaded_single_class_images']
    elif 'single_class_images_path' in notebook_globals:
      # File written by the data-preparation cell (full unpickling: trusted file only).
      images = torch.load(
        notebook_globals['single_class_images_path'], weights_only=False, map_location=device
      )
    else:
      raise NameError(
        "No images available: pass `images`, define `loaded_single_class_images`, "
        "or run the data-preparation cell that defines `single_class_images_path`."
      )

  experiment_id = len(os.listdir(base_dir)) + 2
  # Never reuse (and overwrite) the folder of an earlier experiment.
  while os.path.exists(os.path.join(base_dir, f'experiment_{experiment_id}')):
    experiment_id += 1
  print(f"Experiment {experiment_id}")
  log_dir, output_dir = setup_experiment_dirs(base_dir, experiment_id)

  # save_config expects a directory and writes <dir>/config.yaml itself.
  save_config(CONFIG, output_dir)

  writer = SummaryWriter(log_dir)
  try:
    output_dict = {}
    for i in range(CONFIG['num_images']):
      output_dict[i] = {}

      y = images[i].view(1,3,32,32).to(device)
      test_y = y.clone()
      test_y = test_y.view([3, 32, 32]).clip(-1, 1)
      output_dict[i]['y'] = test_y

      writer.add_image('Ground Truth', denormalize_images(test_y.squeeze(0)), i)

      x_0_list, image_list, losses = d_flow(model, y, N, writer, output_dict, corruption_type, img_num=i)

      # Last image of every outer step, followed by the ground truth.
      grid = create_img_grid(
        [frames[-1].view([3,32,32]).clip(-1,1) for frames in image_list+[y]], _nrow=3, _padding=2
      )

      writer.add_image('D-Flow Step Images', grid, i)
  finally:
    # Flush pending events and release the log file even if a run fails.
    writer.close()

  save_output(os.path.join(output_dir, 'output_dict.pt'), output_dict)
  save_PIL_images(os.path.join(output_dir, 'd_flow_images'), output_dict)

def d_flow(model, y, N, writer, output_dict, corruption_type='inpainting', img_num=0):
  """Optimise the latent ``x_0`` so that the ODE solution matches the known pixels of ``y``.

  ``y`` is corrupted with ``corrupt_image`` and a random subset of the unmasked
  pixel positions is drawn with ``subsample``. ``x_0`` is initialised by
  ``generate_x_0`` and updated with L-BFGS for ``N`` outer steps. Every closure
  evaluation integrates the model ODE from t=0 to t=1 and scores the end point
  against the corrupted image with ``inpainting_loss``.

  Args:
    model: Network evaluated as ``model.forward(t, x)``.
    y: Ground-truth image tensor.
    N: Number of outer optimiser steps.
    writer: TensorBoard ``SummaryWriter`` used for image logging.
    output_dict: Dict with an existing ``output_dict[img_num]`` sub-dict;
      receives "Corrupted Y", 'x_0_list', 'x_1_list' and 'losses'.
    corruption_type: Name understood by ``corrupt_image``.
    img_num: Index of the current image (used as key and log step).

  Returns:
    (x_0_list, image_list, losses): one inner list per outer step, each holding
    one entry per closure evaluation.
  """
  corrupted_y, mask = corrupt_image(y, corruption_type)
  corrupted_y = corrupted_y.to(device)
  # Fraction of known pixels used in the loss; configurable, 0.9 when unset.
  subsamples = subsample(mask, CONFIG.get('inpaint_percent', 0.9))
  corrupted_y.requires_grad = False

  test_corrupted_y = corrupted_y.clone()
  test_corrupted_y = test_corrupted_y.view([3, 32, 32]).clip(-1, 1)

  output_dict[img_num]["Corrupted Y"] = test_corrupted_y
  writer.add_image('Corrupted Y', denormalize_images(test_corrupted_y), img_num)

  ODE_func = lambda t, x: model.forward(t, x)

  x_0 = generate_x_0(model, corrupted_y)

  # Alternatives tried earlier: torch.optim.Adagrad and torch_optimizer.Adahessian
  # (the latter needs loss.backward(create_graph=True) and param.grad reset to None
  # after each step).
  optim = torch.optim.LBFGS(
    [x_0],
    lr=CONFIG['lr'],
    max_iter=CONFIG['max_iter'],
    history_size=20,
    tolerance_grad=1e-7,
    tolerance_change=1e-9,
    line_search_fn="strong_wolfe"
  )
  scheduler = torch.optim.lr_scheduler.ExponentialLR(optim, gamma=CONFIG['gamma'])

  x_0_list, image_list, losses = [], [], []

  for i in tqdm.tqdm(range(N)):
    iter_losses, iter_x_0s, iter_x_1s = [], [], []

    def closure():
      optim.zero_grad()

      x_1 = torchdiffeq.odeint(
        func=ODE_func,
        y0=x_0,
        t=torch.linspace(0, 1, CONFIG['ODE_steps'], device=device),
        atol=1e-4,
        rtol=1e-4,
        method=CONFIG['ODE_type'],
      )[-1]

      loss = inpainting_loss(x_1, corrupted_y, subsamples)

      iter_losses.append(loss.item())
      # clone(): the optimiser updates x_0 in place, and detach() alone shares its
      # storage, which would make every stored snapshot equal to the final x_0.
      iter_x_0s.append(x_0.detach().clone())
      iter_x_1s.append(x_1.detach())

      loss.backward()

      torch.nn.utils.clip_grad_norm_([x_0], max_norm=1) #TODO

      return loss

    optim.step(closure)
    scheduler.step() #TODO

    losses.append(iter_losses)
    x_0_list.append(iter_x_0s)
    image_list.append(iter_x_1s)

    # Logging code: all ODE end points evaluated during this outer step.
    grid = create_img_grid([image.view([3,32,32]).clip(-1,1) for image in iter_x_1s])

    writer.add_image(f'Closure Images/img_{img_num}', grid, i)

  # Loss of every closure evaluation, in evaluation order.
  flattened_data = [item for sublist in losses for item in sublist]

  plt.figure(figsize=(8, 6))

  plt.plot(range(len(flattened_data)), flattened_data, marker='o', label='Loss')

  plt.yscale('log')
  plt.xlabel('Steps')
  plt.ylabel('Loss')
  plt.legend()
  plt.grid(True)
  plt.tight_layout()
  plt.show()

  output_dict[img_num]['x_0_list'] = x_0_list
  output_dict[img_num]['x_1_list'] = image_list
  output_dict[img_num]['losses'] = losses
  return x_0_list, image_list, losses


In [None]:
# Build the UNet and load the checkpoint selected by CONFIG onto `device`.
net_model=load_model(device)

In [None]:
%load_ext tensorboard
%tensorboard --logdir '/content/drive/MyDrive/Mac pro dator googl drive/Universitet/KTH/DD2412/Project/Runs'


In [None]:
# Modify config as needed
# Reload the model so CONFIG changes take effect, then run D-flow on the selected images.
net_model = load_model(device)
evaluate_d_flow(
  model=net_model,
  N=CONFIG['total_iter'],
  corruption_type=CONFIG['corrupt_Type'],
)