<a href="https://colab.research.google.com/github/aparey/Guided-flow-match/blob/main/Flow_Matching_Training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
####Connect Google Drive####
# Mount Google Drive at /content/drive/ so the notebook can reach files stored
# there. Requires the google.colab package (i.e. a Colab runtime).
from google.colab import drive
drive.mount('/content/drive/')

In [None]:
####Importing Libraries####
# Add the project folder on the mounted Drive to the module search path.
# The project-local modules imported later (caption_generation, unet_attn,
# text_encoding) are presumably located there -- TODO confirm.
import sys
sys.path.append('/content/drive/MyDrive/Colab_Notebooks/CS_682/')

In [None]:
!pip install torchdiffeq
!pip install torchmetrics
!pip install torchviz
!pip install torch-fidelity

In [None]:
import torch
import torchvision
from torchvision import transforms
from torch.utils.data import Dataset
from torchmetrics.functional.multimodal import clip_score
from torchmetrics.image.fid import FrechetInceptionDistance
import pickle
from functools import partial
import torchviz
from torchviz import make_dot
import PIL
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import tqdm
from torchdiffeq import odeint
import os

from caption_generation import CIFAR10WithCaptions
from unet_attn import UNet
from text_encoding import reshape_text

In [None]:
####Helper Functions####
def convert(loss_vals):
  """Return a new list with the items of ``loss_vals`` in iteration order."""
  return list(loss_vals)

In [None]:
####Load the dataset####
def uncond_dataset():
  """Build image-only CIFAR-10 train and test datasets.

  Both CIFAR-10 splits are requested from torchvision (``download=True``) under
  the same root folder and converted with ``transforms.ToTensor()``. Each split
  is wrapped so that indexing yields only the image, dropping the class label.

  Returns:
    (flow_train_dataset, flow_test_dataset): the wrapped train and test splits.
  """
  transform = transforms.Compose(
    [transforms.ToTensor()])

  train_dataset = torchvision.datasets.CIFAR10(root='./drive/MyDrive/CS 682/CS682 Project/uncond_dataset/train/', train=True,
                                          download=True, transform=transform)
  test_dataset = torchvision.datasets.CIFAR10(root='./drive/MyDrive/CS 682/CS682 Project/uncond_dataset/train/', train=False,
                                          download=True, transform=transform)

  class Custom_CIFAR_train(torch.utils.data.Dataset):
    """Expose only the image (element 0) of each sample of a wrapped dataset."""
    def __init__(self, train_dataset):
      self.target_imgs = train_dataset
    def __getitem__(self, idx):
      # NOTE(review): other cells in this notebook index/unpack samples as a
      # pair (e.g. `x_0, x_1 = data`); this wrapper returns a single image.
      return self.target_imgs[idx][0]
    def __len__(self):
      # Measure the wrapped dataset. The previous code read the enclosing
      # function's `train_dataset`, so the wrapped test split reported the
      # training split's length.
      return len(self.target_imgs)

  flow_train_dataset = Custom_CIFAR_train(train_dataset)
  flow_test_dataset = Custom_CIFAR_train(test_dataset)

  return flow_train_dataset, flow_test_dataset

In [None]:
def cond_dataset():
  """Unpickle the captioned CIFAR train and test datasets stored on Drive.

  Returns:
    (flow_train_dataset, flow_test_dataset): the two unpickled objects, in
    that order.
  """
  pickle_paths = (
    '/content/drive/MyDrive/Colab_Notebooks/CS_682/Data/CAPTIONED_CIFAR_TRAIN.pkl',
    '/content/drive/MyDrive/Colab_Notebooks/CS_682/Data/CAPTIONED_CIFAR_TEST.pkl',
  )
  loaded = []
  for path in pickle_paths:
    with open(path, 'rb') as handle:
      loaded.append(pickle.load(handle))
  flow_train_dataset, flow_test_dataset = loaded
  return flow_train_dataset, flow_test_dataset

In [None]:
def load_dataset(conditional_gen=False):
  """Return the (train, test) datasets for the requested generation mode.

  Uses ``cond_dataset()`` when ``conditional_gen`` is truthy and
  ``uncond_dataset()`` otherwise.
  """
  build = cond_dataset if conditional_gen else uncond_dataset
  return build()

In [None]:
####Sampling####
def sample_from_dataset(dataset, conditional_dataset=False):
  """Display one randomly chosen sample of ``dataset``.

  Args:
    dataset: indexable dataset with a length. Images are expected as
      channel-first tensors (they are permuted to channel-last for display).
    conditional_dataset: when True the sample is a sequence whose first
      element is the image; otherwise the sample is either an image pair or a
      single image tensor.
  """
  # torch.randint needs an explicit size; draw one index and unwrap it.
  idx = torch.randint(0, len(dataset), (1,)).item()
  sample = dataset[idx]
  if conditional_dataset:
    # Only the image is shown; index instead of unpacking so that the number
    # of extra fields (label, caption, ...) does not matter.
    images = [sample[0]]
  elif isinstance(sample, (tuple, list)):
    g_img, c_img = sample
    images = [g_img, c_img]
  else:
    images = [sample]

  # One axis per image, so that a second image does not draw over the first.
  _, axes = plt.subplots(1, len(images), squeeze=False)
  for ax, img in zip(axes[0], images):
    ax.imshow(img.permute(1,2,0))
  plt.show()

In [None]:
####Train-Val Split####
def split_train_dataset(flow_train_dataset, train_frac):
  """Randomly split a dataset into a training part and a validation part.

  Args:
    flow_train_dataset: dataset to split.
    train_frac: fraction assigned to the training part; the validation part
      receives ``1 - train_frac``.

  Returns:
    (train_subset, val_subset) as produced by ``torch.utils.data.random_split``.
  """
  fractions = [train_frac, 1-train_frac]
  train_subset, val_subset = torch.utils.data.random_split(flow_train_dataset, fractions)
  return train_subset, val_subset

In [None]:
####Create DataLoaders####
def gen_loaders(dataset, batch_size, shuffle=False, num_workers=6):
  """Create a ``DataLoader`` over ``dataset`` with pinned host memory.

  Args:
    dataset: dataset to iterate.
    batch_size: number of samples per batch.
    shuffle: reshuffle the samples every epoch. Defaults to False, which is
      the behaviour this function always had.
    num_workers: number of loader worker processes. Defaults to 6, the value
      that used to be hard-coded; pass 0 to load in the main process.

  Returns:
    The configured ``torch.utils.data.DataLoader``.
  """
  return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
                                     pin_memory=True, num_workers=num_workers)

In [None]:
####Model Initialisation####
def model_init(convnet = True, conditional_gen=False):
  """Construct the ``UNet`` vector-field model.

  Args:
    convnet: accepted but not used by this function.
    conditional_gen: forwarded unchanged to ``UNet(conditional_gen=...)``.

  Returns:
    A new ``UNet`` instance.
  """
  # The conditional and unconditional cases build the model the same way;
  # only the flag handed to the constructor differs.
  return UNet(conditional_gen = conditional_gen)

In [None]:
####Masking Function####
def masking_tokens(tokens, pad_token=0):
  """With probability 0.1, replace every entry of ``tokens`` by ``pad_token``.

  The decision is made once per call, so the whole tensor is either returned
  untouched or fully replaced. (Presumably this lets the model also see
  "empty" conditioning during training -- confirm with the training design.)

  Args:
    tokens: tensor of token values.
    pad_token: value written into a masked tensor. Defaults to 0.

  Returns:
    ``tokens`` itself, or a new tensor filled with ``pad_token`` that has the
    same shape, dtype and device as ``tokens``.
  """
  # full_like keeps dtype and device. The earlier multiplication by a CPU
  # float mask turned integer tokens into floats and could not be applied to
  # tokens living on another device; it also ignored pad_token.
  if torch.rand(1).item() < 0.1:
    return torch.full_like(tokens, pad_token)
  return tokens

In [None]:
####Loss Function####
def loss(vf_flow, x_1, t, reshape_text, tokens = None, conditional_gen = False):
  """Flow-matching regression loss for one batch.

  A source batch ``x_0`` is drawn uniformly from [0, 1), the interpolated
  point ``xt = t * x_1 + (1 - t) * x_0`` is formed, and the model output is
  compared with the target ``x_1 - x_0``.

  Args:
    vf_flow: callable model. Called as ``vf_flow(t, xt)`` in the
      unconditional case and as ``vf_flow(t, (xt, tokens))`` in the
      conditional case, where it must return a pair whose first element is
      the predicted flow.
    x_1: batch of target images; indexed as a 4-D tensor (batch first).
    t: 1-D tensor holding one time value per sample.
    reshape_text: callable turning ``tokens`` into a tensor; used only when
      ``conditional_gen`` is True.
    tokens: conditioning input handed to ``reshape_text``.
    conditional_gen: selects the conditional call convention.

  Returns:
    Scalar tensor: squared error summed over dims 1-3, averaged over the batch.
  """
  # Follow the device of the data instead of assuming "cuda", so the loss
  # also works on CPU or on a non-default GPU.
  device = x_1.device
  t = t.to(device)
  x_0 = torch.rand(x_1.shape).to(device)

  t_b = t[:, None, None, None]  # broadcast the per-sample time over dims 1-3
  xt = t_b*x_1 + (1-t_b)*x_0

  true_flow = x_1 - x_0
  if conditional_gen:
    tokens = masking_tokens(reshape_text(tokens)).to(device)
    predicted_flow, _ = vf_flow(t, (xt, tokens))
  else:
    predicted_flow = vf_flow(t, xt)
  flow_objective_loss = torch.sum((predicted_flow - true_flow)**2, dim=(1,2,3))
  return torch.mean(flow_objective_loss)

In [None]:
####Training Loop####
def train_one_epoch(min_delta, patience, epochs,
                    flow_train_dataset,
                    lr, batch_size, reshape_text,
                    convnet = True, conditional_gen=False,
                    epoch_print = 1, infer_num = 0, epoch_save = 250, final_infer_num = 3):
    """Train the flow-matching model and sample a few images afterwards.

    Despite the name, this runs the full loop over ``epochs`` epochs: it
    builds the loader, model and Adam optimizer, trains on CUDA, periodically
    saves the model, saves the final model, and finally generates
    ``final_infer_num`` images with ``inference``.

    Args:
      min_delta, patience, infer_num: accepted but not used by this function.
      epochs: number of passes over the training loader.
      flow_train_dataset: training dataset. Conditional samples are read as
        (image, ..., caption at index 2, ...); unconditional samples may be a
        pair whose second element is the target image or a bare image tensor.
      lr: Adam learning rate.
      batch_size: loader batch size.
      reshape_text: text-encoding object; it must provide ``.to('cuda')``.
      convnet: forwarded to ``model_init``.
      conditional_gen: train and sample with captions when True.
      epoch_print: print the latest loss every this many epochs.
      epoch_save: plot the loss and save the model every this many epochs.
      final_infer_num: number of images sampled after training.

    Returns:
      (loss_vals, infer_imgs, final_images, vf_flow): per-iteration loss
      values, an always-empty list kept for interface compatibility, the
      sampled images (on CPU, permuted with ``(1, 2, 0)``), and the trained
      model.
    """
    loss_vals = []
    infer_imgs = []  # never filled here; returned so callers can still unpack it
    flow_train_loader = gen_loaders(flow_train_dataset, batch_size)
    vf_flow = model_init(convnet, conditional_gen)
    vf_flow.to('cuda')
    optimizer = torch.optim.Adam(vf_flow.parameters(), lr = lr)
    # Loop-invariant: move the text encoder once instead of once per batch.
    text_encoder = reshape_text.to('cuda')

    for epoch in range(epochs):
      # NOTE(review): both save paths below contain "Conditional" even when
      # conditional_gen is False -- confirm this is the intended location.
      if (epoch % epoch_save == 0 and epoch != 0):
        plot_loss(loss_vals)
        torch.save(vf_flow, f'/content/drive/MyDrive/Colab_Notebooks/CS_682/Models/Conditional/Intermediate/conditional_unet_honey_ham_epoch_{epoch}.pth')
      if (epoch % epoch_print == 0 and epoch!=0):
        # Reports the loss of the last batch processed so far.
        print(f'Epoch: {epoch} \nLoss: {loss_vals[-1]}')
        print('-------------------')
      for data in flow_train_loader:
        captions = None
        if conditional_gen:
          # Index instead of unpacking: only the image and caption are used.
          x_1, captions = data[0], data[2]
        elif isinstance(data, (list, tuple)):
          x_1 = data[1]
        else:
          # Image-only dataset: the batch itself is the target tensor. The
          # previous `x_0, x_1 = data` could not unpack such a batch, and
          # len(data[0]) was the channel count rather than the batch size.
          x_1 = data
        x_1 = x_1.to('cuda')
        t = torch.rand(len(x_1)).to('cuda')

        if conditional_gen:
          loss_val = loss(vf_flow, x_1, t, text_encoder, captions, conditional_gen=True)
        else:
          loss_val = loss(vf_flow, x_1, t, text_encoder)

        loss_vals.append(loss_val.item())

        # Zero your gradients for every batch!
        optimizer.zero_grad()

        loss_val.backward()

        # Adjust learning weights
        optimizer.step()

    torch.cuda.empty_cache()

    final_images = []
    torch.save(vf_flow, '/content/drive/MyDrive/Colab_Notebooks/CS_682/Models/Conditional/Final/conditional_unet_honey_ham_final.pth')
    for _ in range(final_infer_num):
      idx = torch.randint(0, len(flow_train_dataset), (1,)).item()
      if conditional_gen:
        # Condition on the caption of a randomly chosen training sample.
        img = inference(vf_flow, caption = flow_train_dataset[idx][2], reshape_text=reshape_text, conditional_gen=True).to('cpu').permute(1,2,0)
      else:
        img = inference(vf_flow, conditional_gen=False).to('cpu').permute(1,2,0)
      plt.imshow(img)
      plt.show()
      final_images.append(img)
    plot_loss(loss_vals)
    return loss_vals, infer_imgs, final_images, vf_flow

In [None]:
####Model Visualization####
def model_viz_print(model):
  """Print ``model`` using its string representation."""
  print(model)

def model_viz_graph(model, conditional_gen=False):
  """Build a torchviz graph of one forward pass of ``model``.

  A random batch of two 32x32 inputs (4 channels when ``conditional_gen`` is
  True, 3 otherwise) and the time values ``[0.0, 1.0]`` are fed to
  ``model(t, x)``; both tensors are created on the CPU.

  Args:
    model: callable invoked as ``model(t, x)``.
    conditional_gen: selects the 4-channel input.

  Returns:
    The object returned by ``make_dot``. It used to be discarded, which left
    the function without any observable effect; returning it lets a notebook
    cell display or render the graph.
  """
  # NOTE(review): elsewhere in this notebook the conditional model is called
  # with an (image, tokens) pair; confirm the 4-channel input is still valid.
  channels = 4 if conditional_gen else 3
  x = torch.rand(2,channels,32,32)
  t = torch.FloatTensor([0.0,1.0])
  return make_dot(model(t,x))

In [None]:
####ODE Solver for Inference####
def inference(model, caption=None, reshape_text=None, conditional_gen=False):
  """Generate one image by integrating the model's flow from t=0 to t=1.

  Starts from a uniform-random ``(1, 3, 32, 32)`` tensor and solves the ODE
  with ``odeint`` (dopri5, atol = rtol = 1e-5) without tracking gradients.

  Args:
    model: the vector field passed to ``odeint``.
    caption: caption to condition on; used only when ``conditional_gen``.
    reshape_text: callable applied to ``[caption]`` to obtain the token
      tensor; required when ``conditional_gen`` is True.
    conditional_gen: integrate the ``(image, tokens)`` state instead of the
      image alone.

  Returns:
    The image at the final time, i.e. element ``[-1, 0]`` of the solution
    (time axis first, then the single batch entry).
  """
  # Run on the device that holds the model; fall back to the formerly
  # hard-coded 'cuda' when the model exposes no parameters.
  try:
    device = next(model.parameters()).device
  except (AttributeError, StopIteration):
    device = torch.device('cuda')

  x_0 = torch.rand(1,3,32,32).to(device)
  if conditional_gen:
    tokens = reshape_text([caption]).to(device)

  # Only the end points are requested. A denser grid such as
  # torch.linspace(0.0, 1.0, 10) would also return intermediate images.
  t = torch.tensor([0.0, 1.0]).to(device)

  with torch.no_grad():
    if conditional_gen:
      # For a tuple state odeint returns one solution per element; keep the image.
      x_1, _ = odeint(model, (x_0, tokens), t, method='dopri5', atol=1e-5, rtol=1e-5)
    else:
      x_1 = odeint(model, x_0, t, method='dopri5', atol=1e-5, rtol=1e-5)
  return x_1[-1,0]

In [None]:
####Loss Curve Plot####
def plot_loss(loss_vals):
  """Plot the recorded loss values against their index and show the figure."""
  iterations = list(range(len(loss_vals)))
  losses = convert(loss_vals)
  plt.plot(iterations, losses)
  plt.xlabel('#iterations')
  plt.ylabel('Loss')
  plt.show()

In [None]:
####Unconditional Flow Matching####
# Load the unconditional datasets, then keep 98% of the training data for
# training and the remainder for validation.
flow_train_dataset, flow_test_dataset = load_dataset(conditional_gen=False)
flow_train_dataset, flow_val_dataset = split_train_dataset(
    flow_train_dataset, train_frac=0.98)
# Number of training samples shown in the preview grid.
ds_size = 30

In [None]:
from mpl_toolkits.axes_grid1 import ImageGrid
import math

# Grid layout: computed once, outside the loop, with at least one row so the
# grid can still be created when ds_size is 0.
ncols = 10
nrows = max(1, math.ceil(ds_size / ncols))

imgs = []
for idx in range(ds_size):
  sample = flow_train_dataset[idx]
  # A paired sample carries the image in position 1; an image-only dataset
  # returns the image tensor itself.
  image = sample[1] if isinstance(sample, (tuple, list)) else sample
  imgs.append(image.permute(1,2,0))  # permute to channel-last for imshow

fig = plt.figure(figsize=(10., 4.))
grid = ImageGrid(fig, 111,  # similar to subplot(111)
                 nrows_ncols=(nrows, ncols),  # creates an nrows x ncols grid of axes
                 axes_pad=0.1,  # pad between axes in inch.
                 )

for ax, im in zip(grid, imgs):
    # Iterating over the grid returns the Axes.
    ax.imshow(im)

plt.show()

In [None]:
# Settings for the unconditional run; they are echoed below for the record.
min_delta, patience = 0, 10
epochs, lr, batch_size = 2000, 1e-4, 128
epoch_print, epoch_save = 1, 6
infer_num, final_infer_num = 0, 10
conditional_gen, convnet = False, True

print('\n'.join([
    f'train for <{epochs}> epochs',
    f'learning rate is <{lr}>',
    f'batch size is <{batch_size}>',
    f'conditional generation is set to <{conditional_gen}>',
    f'print loss every <{epoch_print}> epochs',
    f'generate <{infer_num}> images every <{epoch_print}> epochs',
    f'save intermediate models every <{epoch_save}> epochs',
    f'generate <{final_infer_num}> images after training',
]))

In [None]:
# Run training with the settings defined above and keep the loss history,
# the sampled images and the trained model.
loss_vals, infer_imgs, final_infers, model = train_one_epoch(
    min_delta, patience, epochs,
    flow_train_dataset,
    lr, batch_size, reshape_text,
    convnet=convnet,
    conditional_gen=conditional_gen,
    epoch_print=epoch_print,
    infer_num=infer_num,
    epoch_save=epoch_save,
    final_infer_num=final_infer_num,
)

In [None]:
from matplotlib.pyplot import set_loglevel
from mpl_toolkits.axes_grid1 import ImageGrid
import math

plot_loss(loss_vals)

total_infers = len(infer_imgs)
grid_cols = 10

# Quieten matplotlib's logging while the images are drawn (presumably to hide
# imshow range warnings -- confirm); the level is restored even on error.
plt.set_loglevel("critical")
try:
    # Show the intermediate samples first, then the final ones. An empty list
    # is skipped: it would otherwise request a figure with zero rows. The row
    # count follows the actual number of images instead of a fixed constant,
    # and the stale `final_infer_num = 3` override is no longer applied.
    for images in (infer_imgs, final_infers):
        if not images:
            continue
        grid_rows = math.ceil(len(images) / grid_cols)
        fig = plt.figure(figsize=(grid_cols, grid_rows))
        grid = ImageGrid(fig, 111,  # similar to subplot(111)
                         nrows_ncols=(grid_rows, grid_cols),  # rows x columns of axes
                         axes_pad=0.1,  # pad between axes in inch.
                         )

        for ax, im in zip(grid, images):
            # Iterating over the grid returns the Axes.
            ax.imshow(im)

        plt.show()
finally:
    plt.set_loglevel("warning")

In [None]:
####Conditional Flow Matching####
# Load the captioned datasets, then keep 98% of the training data for
# training and the remainder for validation.
flow_train_dataset, flow_test_dataset = load_dataset(conditional_gen=True)
flow_train_dataset, flow_val_dataset = split_train_dataset(
    flow_train_dataset, train_frac=0.98)
# Number of training samples shown in the preview grid.
ds_size = 30

In [None]:
from mpl_toolkits.axes_grid1 import ImageGrid
import math

# Derive the grid from ds_size instead of hard-coding (3, 10), so changing
# ds_size keeps the preview complete; at least one row is always requested.
ncols = 10
nrows = max(1, math.ceil(ds_size / ncols))

# Element 0 of each sample is the image; permute to channel-last for imshow.
imgs = [flow_train_dataset[idx][0].permute(1,2,0) for idx in range(ds_size)]

fig = plt.figure(figsize=(10., 4.))
grid = ImageGrid(fig, 111,  # similar to subplot(111)
                 nrows_ncols=(nrows, ncols),  # creates an nrows x ncols grid of axes
                 axes_pad=0.1,  # pad between axes in inch.
                 )

for ax, im in zip(grid, imgs):
    # Iterating over the grid returns the Axes.
    ax.imshow(im)

plt.show()

In [None]:
# Print element 2 of each of the first ds_size training samples (the value
# that is later passed to inference as the caption).
for sample_idx in range(ds_size):
  caption = flow_train_dataset[sample_idx][2]
  print(caption)

In [None]:
# Settings for the conditional run; they are echoed below for the record.
min_delta, patience = 0, 10
epochs, lr, batch_size = 2500, 3e-4, 128
epoch_print, epoch_save = 1, 6
infer_num, final_infer_num = 0, 10
conditional_gen, convnet = True, True

print('\n'.join([
    f'train for <{epochs}> epochs',
    f'learning rate is <{lr}>',
    f'batch size is <{batch_size}>',
    f'conditional generation is set to <{conditional_gen}>',
    f'print loss every <{epoch_print}> epochs',
    f'generate <{infer_num}> images every <{epoch_print}> epochs',
    f'save intermediate models every <{epoch_save}> epochs',
    f'generate <{final_infer_num}> images after training',
]))

In [None]:
# Run training with the settings defined above and keep the loss history,
# the sampled images and the trained model.
loss_vals, infer_imgs, final_infers, model = train_one_epoch(
    min_delta, patience, epochs,
    flow_train_dataset,
    lr, batch_size, reshape_text,
    convnet=convnet,
    conditional_gen=conditional_gen,
    epoch_print=epoch_print,
    infer_num=infer_num,
    epoch_save=epoch_save,
    final_infer_num=final_infer_num,
)

In [None]:
from matplotlib.pyplot import set_loglevel
from mpl_toolkits.axes_grid1 import ImageGrid
import math

plot_loss(loss_vals)

total_infers = len(infer_imgs)
grid_cols = 10

# Quieten matplotlib's logging while the images are drawn (presumably to hide
# imshow range warnings -- confirm); the level is restored even on error.
plt.set_loglevel("critical")
try:
    # Show the intermediate samples first, then the final ones. An empty list
    # is skipped: it would otherwise request a figure with zero rows. The row
    # count follows the actual number of images instead of a fixed constant,
    # and the stale `final_infer_num = 3` override is no longer applied.
    for images in (infer_imgs, final_infers):
        if not images:
            continue
        grid_rows = math.ceil(len(images) / grid_cols)
        fig = plt.figure(figsize=(grid_cols, grid_rows))
        grid = ImageGrid(fig, 111,  # similar to subplot(111)
                         nrows_ncols=(grid_rows, grid_cols),  # rows x columns of axes
                         axes_pad=0.1,  # pad between axes in inch.
                         )

        for ax, im in zip(grid, images):
            # Iterating over the grid returns the Axes.
            ax.imshow(im)

        plt.show()
finally:
    plt.set_loglevel("warning")