In [1]:
import numpy as np
import torch
from diffusionmodel import *
from torch.optim import Adam
import bgflow.distribution.sampling.mcmc as MCMC
import bgflow.distribution.energy.double_well as DoubleWell
import bgflow.distribution.normal as Normal
import matplotlib.pyplot as plt


# Seed torch's global RNG so weight init, MCMC sampling and diffusion noise are reproducible.
torch.random.manual_seed(199)

# Prefer the GPU, but fall back to CPU so the notebook still runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Number of diffusion steps; also the length of the beta schedule built further down.
T = 100


In [None]:
# Target energy: one-dimensional double well with coefficients b = -4 and c = 1.
target = DoubleWell.DoubleWellEnergy(c=1.0, b=-4.0, dim=1)
# One-dimensional normal distribution.
# NOTE(review): `prior` is not referenced anywhere else in the visible notebook.
prior = Normal.NormalDistribution(dim=1)

# MLP handed to DiffusionModel as `net`: widths 2 -> 64 -> 128 -> 64 -> 1 with a SiLU
# after every Linear except the last.
# NOTE(review): the 2 input features are presumably (x, t) — confirm against DiffusionModel.
net = torch.nn.Sequential(*[
    layer
    for n_in, n_out in zip((2, 64, 128, 64), (64, 128, 64, 1))
    for layer in (torch.nn.Linear(n_in, n_out), torch.nn.SiLU())
][:-1])  # drop the trailing activation so the network ends on the output Linear


# (x, t) evaluation grid for the forward-process density contour plot.
# ts lists exactly the diffusion steps visited by the noising loop, range(0, T, 2).
# This keeps the plot's time axis aligned with the steps actually used, and keeps
# len(ts) equal to the loop length for any T. The dtype matches xs because
# torch.meshgrid needs both inputs to share a dtype.
ts = torch.arange(0, T, 2, dtype=torch.get_default_dtype())
xs = torch.linspace(-3, 3, 50)

# indexing='ij' (the current default) is spelled out to silence torch's warning.
# X[i, j] = xs[i] and Y[i, j] = ts[j].
X, Y = torch.meshgrid(xs, ts, indexing='ij')

In [None]:
# Reference samples from the double-well target via bgflow's GaussianMCMCSampler,
# started from the single state x = 0.
# NOTE(review): no burn-in is requested here, and consecutive MCMC samples are correlated.
sampler = MCMC.GaussianMCMCSampler(energy=target, init_state=torch.tensor([0.]))

from utils import load_or_generate_and_then_save
datafilepath = 'double_well.npy'
# Presumably the helper loads `datafilepath` if it exists; otherwise it runs the lambda and
# saves the result there. Confirm this in utils. torch.from_numpy requires an ndarray back.
data = torch.from_numpy(load_or_generate_and_then_save(datafilepath, lambda : sampler.sample(n_samples=50000)))

# Normalised histogram of the reference samples, using the grid points xs as bin edges.
counts, bins = np.histogram(data, bins=xs, density=True)
plt.stairs(counts, bins, fill=True)
plt.xlabel('x')
plt.ylabel('density')
plt.title('MCMC samples from the double-well target')
plt.show()

In [None]:
# T linearly spaced values from 1e-4 to 0.05, passed to DiffusionModel as its variance schedule.
beta_schedule = torch.linspace(1e-4, 0.05, T)
diff_model = DiffusionModel(net=net, variance_schedule=beta_schedule, device=device)
# Total number of scalar parameters. len(p) would count only the first dimension of each
# tensor (e.g. 64 for a 64x2 weight matrix), so numel() is used instead.
print(sum(p.numel() for p in diff_model.parameters()))

In [None]:
# Empirical density of the noised samples x_t at every second diffusion step,
# drawn as a filled contour over the (x, t) grid.

# Bin edges are centred on the grid points xs, so z[j, i] is the density at xs[j] rather
# than at a bin's left edge. This also avoids an oversized overflow bin at the right end.
bin_width = xs[1] - xs[0]
bin_edges = torch.cat([xs - bin_width / 2, xs[-1:] + bin_width / 2]).cpu()
data_on_device = data.to(device)  # loop-invariant: move the data once

z = torch.zeros_like(X)
for i, t in enumerate(range(0, T, 2)):
    t_s = torch.full(data.shape, t).to(device)
    x_t = diff_model.apply_noise(x_0=data_on_device, t_s=t_s)[0]
    z[:, i] = torch.histogram(x_t.cpu(), bins=bin_edges, density=True)[0]

# `keepdim` is a torch reduction argument, not a contourf argument, so it is not passed.
im1 = plt.contourf(X.cpu(), Y.cpu(), z, levels=np.linspace(0, 0.5, 50))
plt.colorbar(im1, label='density')
plt.xlabel('x')
plt.ylabel('diffusion step t')
plt.title('Density of noised samples $x_t$');

In [None]:
batch_size = 128

# The samples come from a single MCMC chain, so neighbouring samples are correlated.
# Without shuffling, each mini-batch would cover one short stretch of the chain
# (often a single well), giving biased gradient estimates. Reshuffle every epoch instead.
train_loader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True)
optimizer = Adam(diff_model.parameters(), lr=1e-3)
# Halve the LR once the monitored metric has stopped improving for `patience` scheduler steps.
# NOTE(review): `verbose` is deprecated in recent PyTorch releases.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=5, cooldown=3, verbose=True)

from train import train
def callback(model):
    """Show a histogram of 1000 samples drawn from ``model`` over the grid ``xs``.

    Passed to ``train`` below, which invokes it periodically during training.

    Args:
        model: the model supplied by the training loop. It must provide
            ``sample(shape)`` returning a tensor, as DiffusionModel does.
    """
    # Sample from the model handed in rather than from the global `diff_model`, so the
    # callback always reflects the model it is actually called with.
    samples = model.sample([1000, 1])

    counts, bins = np.histogram(samples.cpu(), bins=xs, density=True)
    plt.stairs(counts, bins, fill=True)
    plt.xlabel('x')
    plt.ylabel('density')
    plt.show()


# Fit the diffusion model on the MCMC data with an MSE objective.
# The callback (sample histogram) presumably fires every `callback_interval` iterations.
train(
    diff_model,
    data_loader=train_loader,
    optimizer=optimizer,
    scheduler=scheduler,
    loss_fn=torch.nn.MSELoss(),
    n_iterations=50,
    callback=callback,
    callback_interval=5,
    device=device,
)


In [None]:
samples = diff_model.sample([10000, 1])

# Overlay normalised histograms of the model samples and the MCMC reference data.
# Both use the same bins so the two densities are directly comparable.
for values, label in ((samples.cpu(), "Model samples"), (data, "data samples")):
    counts, bins = np.histogram(values, bins=xs, density=True)
    plt.stairs(counts, bins, fill=True, alpha=0.5, label=label)
plt.xlabel('x')
plt.ylabel('density')
plt.title('Diffusion model samples vs. MCMC reference')
plt.legend();