In [None]:
import os
import subprocess

def git_repo_root():
    """Return the top-level directory of the enclosing Git repository.

    Runs ``git rev-parse --show-toplevel`` in the current working directory.

    Returns:
        str | None: The repository root path with surrounding whitespace
        stripped, or ``None`` when the root cannot be determined (the current
        directory is not inside a repository, or the ``git`` executable is
        missing / cannot be launched).
    """
    try:
        root = subprocess.check_output(
            ['git', 'rev-parse', '--show-toplevel'],
            universal_newlines=True,
        ).strip()
        return root
    except subprocess.CalledProcessError:
        # git ran but exited non-zero: we are not inside a Git repository.
        return None
    except OSError:
        # git could not be started at all (e.g. FileNotFoundError when it is
        # not installed or not on PATH); treat it like "no repository" so the
        # caller's fallback branch runs instead of the cell crashing.
        return None

# Look up the repository root once and, when found, make it the working
# directory so that the project-local imports in later cells resolve.
git_root = git_repo_root()

if not git_root:
    # No root was returned (None or empty string): stay where we are.
    print("Not inside a Git repository.")
else:
    os.chdir(git_root)
    print(f"Changed working directory to: {git_root}")

In [None]:
%load_ext autoreload
%autoreload 2

from diffusion import BridgeDiffusionVPSDE
from data import generate_mixture_gaussians
from data import generate_happy_face

# Sanity check: make sure our diffusion process actually builds the bridge.
data_x = generate_happy_face(500)
sde = BridgeDiffusionVPSDE(
    generate_mixture_gaussians,
    bmin=0.1,
    bmax=1,
)

# Known issue: the Euler method fails at step 1000; the cause has not been
# tracked down yet.
sde.plot_forward_diffusion(data_x)

In [None]:
%load_ext autoreload
%autoreload 2

from torch.utils.data import DataLoader, TensorDataset
from training import train_score_network
from model import Bridge_Diffusion_Net

# Bridge SDE built from generate_mixture_gaussians with bmin=0.1 and bmax=1.
sde = BridgeDiffusionVPSDE(
    generate_mixture_gaussians,
    bmin=0.1,
    bmax=1,
)

# 32,000 happy-face samples, served in shuffled mini-batches of 500.
data = generate_happy_face(num_samples=32000)
dataloader = DataLoader(
    TensorDataset(data),
    batch_size=500,
    shuffle=True,
)

# NOTE(review): input_dim=4 / output_dim=2 are taken as-is; what the four
# inputs consist of is defined by Bridge_Diffusion_Net -- confirm there.
score_net = Bridge_Diffusion_Net(input_dim=4, output_dim=2)

# The return value is not stored; score_net is reused by later cells, so
# training presumably updates the network in place -- TODO confirm.
train_score_network(
    dataloader,
    score_net,
    sde,
    epochs=15000,
    bridge=True,
)

In [None]:
%load_ext autoreload
%autoreload 2
from diffusion import BridgeDiffusionVPSDE
import matplotlib.pyplot as plt

import math

def plot_x_snapshots(x_snapshots, labels=None, num_intervals=9):
    """Scatter-plot a sequence of 2-D snapshots on a grid of subplots.

    Snapshots are laid out row by row, five per row; grid cells left over in
    the last row are hidden.

    Args:
        x_snapshots: Sequence of array-likes. Each snapshot is indexed as
            ``snapshot[:, 0]`` (x coordinates) and ``snapshot[:, 1]``
            (y coordinates).
        labels: Optional sequence of subplot titles, matched to snapshots by
            position. Snapshots without a label (``labels`` is ``None`` or
            shorter than ``x_snapshots``) get the default title.
        num_intervals: Denominator of the default title: snapshot ``i`` is
            titled ``T = 1 - i / num_intervals``. Defaults to 9, the value
            that used to be hard-coded here.

    Returns:
        None. The figure is shown with ``plt.show()``; when there is nothing
        to plot a message is printed instead.

    Raises:
        ValueError: If ``num_intervals`` is not positive.
    """
    # An explicit None/len check (instead of truthiness) so that array-like
    # containers, whose truth value is ambiguous, are accepted as well.
    if x_snapshots is None or len(x_snapshots) == 0:
        print("No data to plot.")
        return
    if num_intervals <= 0:
        raise ValueError("num_intervals must be a positive integer")

    num_plots = len(x_snapshots)
    num_cols = 5  # You can adjust this number based on your preference
    num_rows = math.ceil(num_plots / num_cols)

    # squeeze=False keeps `axes` two-dimensional even for a single row.
    _, axes = plt.subplots(
        num_rows,
        num_cols,
        figsize=(num_cols * 3, num_rows * 3),
        squeeze=False,
    )

    for i, snapshot in enumerate(x_snapshots):
        ax = axes[i // num_cols, i % num_cols]
        ax.scatter(snapshot[:, 0], snapshot[:, 1])

        if labels is not None and i < len(labels):
            title = str(labels[i])
        else:
            title = f"T={(1 - i / num_intervals):.2f} "
        ax.set_title(title)

    # Hide any unused subplots
    for j in range(num_plots, num_rows * num_cols):
        axes.flat[j].set_visible(False)

    plt.tight_layout()
    plt.show()

# Bridge SDE built from generate_mixture_gaussians with bmin=0.1 and bmax=1.
sde = BridgeDiffusionVPSDE(
    generate_mixture_gaussians,
    bmin=0.1,
    bmax=1,
)

# n + 1 evenly spaced fractions in [0, 1], scaled by 1000 and truncated to
# integer step indices (1000 is presumably the number of solver steps --
# TODO confirm against backward_diffusion1).
n = 9
plot_intervals = [k / n for k in range(n + 1)]
plot_steps = [int(frac * 1000) for frac in plot_intervals]
# Pull the final index back by one (1000 -> 999).
plot_steps[-1] -= 1

x_snapshots = sde.backward_diffusion1(score_net, plot_steps=plot_steps)

plot_x_snapshots(x_snapshots, labels=None)
