# Score-based Generative Modeling with SDEs (Simple examples)

In [1]:
import os

In [2]:
import matplotlib.pyplot as plt
import torch
import torch.optim as optim
import torch.nn as nn
import lib.toy_data as toy_data
import numpy as np
import pickle

Basic parameters

In [3]:
learning_rate = 1e-3 # learning rate for training neural network
batch_size = 10000  # batch size during training of neural network
N_samples = 10000
epochs = 100000   # Number of training epochs for the neural network
sigma_max = 100.0 # 100 for df=1.0
sigma_min = 0.01  # 0.01 for df=1.0
T = 5    # Forward simulation time in the forward SDE (fixed)
dataset = 'Heavytail_submanifold' # Dataset choice, see toy_data for full options of toy datasets
d = 10
d_orth=100

We first initialize the neural net that models the score function. 

In [4]:
## Model construction

class DenoisingModel(nn.Module):
    
    def __init__(self, hidden_units=128):
        super(DenoisingModel, self).__init__()
        # hidden_units = 32
        
        # data and timestep
        self.fc1 = nn.Linear(d+d_orth+1, int(hidden_units), bias=True)
        nn.init.xavier_uniform_(self.fc1.weight)
        self.activation1 = nn.GELU()
        self.fc2 = nn.Linear(int(hidden_units), int(hidden_units), bias=True)
        nn.init.xavier_uniform_(self.fc2.weight)
        self.activation2 = nn.GELU()
        self.fc3 = nn.Linear(int(hidden_units), d+d_orth+1, bias=True)
        nn.init.xavier_uniform_(self.fc3.weight)
        self.activation3 = nn.GELU()
        
        self.fc4 = nn.Linear(d+d_orth+1, int(hidden_units), bias=True)
        nn.init.xavier_uniform_(self.fc4.weight)
        self.activation4 = nn.GELU()
        self.fc5 = nn.Linear(int(hidden_units), int(hidden_units), bias=True)
        nn.init.xavier_uniform_(self.fc5.weight)
        self.activation5 = nn.GELU()
        self.fc6 = nn.Linear(int(hidden_units), d+d_orth+1, bias=True)
        nn.init.xavier_uniform_(self.fc6.weight)
        self.activation6 = nn.GELU()
        
        self.fc7 = nn.Linear(d+d_orth+1, int(hidden_units), bias=True)
        nn.init.xavier_uniform_(self.fc7.weight)
        self.activation7 = nn.GELU()
        self.fc8 = nn.Linear(int(hidden_units), int(hidden_units), bias=True)
        nn.init.xavier_uniform_(self.fc8.weight)
        self.activation8 = nn.GELU()
        self.fc9 = nn.Linear(int(hidden_units), d+d_orth, bias=True)
        nn.init.xavier_uniform_(self.fc9.weight)
        
        
    def forward(self, x):
        x = self.fc1(x)
        x = self.activation1(x)
        x = self.fc2(x)
        x = self.activation2(x)
        x = self.fc3(x)
        x = self.activation3(x)
        x = self.fc4(x)
        x = self.activation4(x)
        x = self.fc5(x)
        x = self.activation5(x)
        x = self.fc6(x)
        x = self.activation6(x)
        x = self.fc7(x)
        x = self.activation7(x)
        x = self.fc8(x)
        x = self.activation8(x)
        x = self.fc9(x)
        
        return x
      
scorenet = DenoisingModel()
print(scorenet)
optimizer = optim.Adam(scorenet.parameters(), lr=learning_rate)

DenoisingModel(
  (fc1): Linear(in_features=111, out_features=256, bias=True)
  (activation1): GELU(approximate='none')
  (fc2): Linear(in_features=256, out_features=256, bias=True)
  (activation2): GELU(approximate='none')
  (fc3): Linear(in_features=256, out_features=111, bias=True)
  (activation3): GELU(approximate='none')
  (fc4): Linear(in_features=111, out_features=256, bias=True)
  (activation4): GELU(approximate='none')
  (fc5): Linear(in_features=256, out_features=256, bias=True)
  (activation5): GELU(approximate='none')
  (fc6): Linear(in_features=256, out_features=111, bias=True)
  (activation6): GELU(approximate='none')
  (fc7): Linear(in_features=111, out_features=256, bias=True)
  (activation7): GELU(approximate='none')
  (fc8): Linear(in_features=256, out_features=256, bias=True)
  (activation8): GELU(approximate='none')
  (fc9): Linear(in_features=256, out_features=110, bias=True)
)


Define loss functions. These loss functions assume that the forward process is a variance exploding process $dx =  + \sigma_t dW$. The choice of $\lambda(t)$ in the SGM objective function is equal to 1 (the constant in front of the dW term). 

In [5]:
# Loss function -- we use the denoising diffusions objective function
# Scorenet is the score model, samples are the training samples, Tmin/Tmax are the time interval that is being trained on, and eps is so that Tmin is not sampled. 

def time_dsm_score_estimator(scorenet,samples,sigma_min,sigma_max,Tmin,Tmax,eps):

    t = torch.rand(samples.shape[0]) * (Tmax - Tmin - eps) + eps + Tmin # sample uniformly from time interval

    # Add noise to the training samples
    z = torch.randn_like(samples)
    sigmas = sigma_min * (sigma_max/sigma_min)**((t - Tmin)/(Tmax - Tmin))
    sigmas = sigmas.view(samples.shape[0],*([1]*len(samples.shape[1:])))
    noise = z * sigmas
    #tenlarge = t.repeat(2,1).T
    perturbed_samples = samples + noise

    # Evaluate score and marginal score on the perturbed samples
    #target = - 1/ (sigmas ** 2) * (noise)
    score_eval_samples = torch.cat((t.reshape(-1,1),perturbed_samples),1)
    scores = scorenet(score_eval_samples)

    # Evaluate the loss function 
    #target = target.view(target.shape[0],-1)
    scores = scores.view(scores.shape[0],-1)
    loss = 0.5 * ((scores * sigmas + z) ** 2).sum(dim = -1) 

    return loss.mean(dim = 0)


# Loss function
# This is for if you have a specific mesh for the time interval you would like the network to train on. 
def deterministic_time_dsm_score_estimator(scorenet,samples,t):

    loss = 0
    for ii in range(len(t)-1):

        # Add noise to the training samples
        sigmas = torch.sqrt(1 - torch.exp(-t[ii]))
        noise = torch.randn_like(samples) * sigmas
        perturbed_samples = samples * torch.exp(-0.5 * t[ii]) + noise

        # Evaluate score and marginal score on perturbed samples
        target = - 1/ (sigmas ** 2) * (noise)
        score_eval_samples = torch.cat((t[ii].repeat(perturbed_samples.shape[0],1),perturbed_samples),1)
        scores = scorenet(score_eval_samples)

        # Evaluate loss function at this particular t[ii]
        target = target.view(target.shape[0],-1)
        scores = scores.view(scores.shape[0],-1)
        loss_vec = 0.5 * ((scores-target) ** 2).sum(dim = -1) 
        loss = loss + (t[ii+1]-t[ii])*loss_vec.mean(dim = 0)

    return loss


Training the score network

In [6]:
# Training the score network
# sample toy_data
#p_samples_all = toy_data.inf_train_gen(dataset, N_samples)
p_samples = toy_data.inf_train_gen(dataset, N_samples)

for step in range(epochs):
    #for i in range(int(N_samples//batch_size)):
    #    p_samples = p_samples_all[i*batch_size:(i+1)*batch_size,:]
    samples = torch.tensor(p_samples).to(dtype = torch.float32)

    # evaluate loss function and gradient
    loss = time_dsm_score_estimator(scorenet,samples,sigma_min, sigma_max, 0,T,eps = 0.0005)
    optimizer.zero_grad()
    loss.backward()

    # Update score network
    optimizer.step()

    if not step%100:
        print(loss,step)




tensor(1.7372e+10, grad_fn=<MeanBackward1>) 0
tensor(113.7866, grad_fn=<MeanBackward1>) 100
tensor(57.8126, grad_fn=<MeanBackward1>) 200
tensor(55.4715, grad_fn=<MeanBackward1>) 300
tensor(55.3222, grad_fn=<MeanBackward1>) 400
tensor(55.1756, grad_fn=<MeanBackward1>) 500
tensor(55.1019, grad_fn=<MeanBackward1>) 600
tensor(54.9758, grad_fn=<MeanBackward1>) 700
tensor(55.0227, grad_fn=<MeanBackward1>) 800
tensor(55.0504, grad_fn=<MeanBackward1>) 900
tensor(55.0073, grad_fn=<MeanBackward1>) 1000
tensor(55.0374, grad_fn=<MeanBackward1>) 1100
tensor(55.0821, grad_fn=<MeanBackward1>) 1200
tensor(55.0198, grad_fn=<MeanBackward1>) 1300
tensor(55.1877, grad_fn=<MeanBackward1>) 1400
tensor(54.9707, grad_fn=<MeanBackward1>) 1500
tensor(54.9946, grad_fn=<MeanBackward1>) 1600
tensor(55.0662, grad_fn=<MeanBackward1>) 1700
tensor(54.8866, grad_fn=<MeanBackward1>) 1800
tensor(54.9048, grad_fn=<MeanBackward1>) 1900
tensor(54.9530, grad_fn=<MeanBackward1>) 2000
tensor(54.9644, grad_fn=<MeanBackward1>) 2

tensor(137.3148, grad_fn=<MeanBackward1>) 34400
tensor(130.4629, grad_fn=<MeanBackward1>) 34500
tensor(135.0703, grad_fn=<MeanBackward1>) 34600
tensor(129.0948, grad_fn=<MeanBackward1>) 34700
tensor(141.4388, grad_fn=<MeanBackward1>) 34800
tensor(145.9889, grad_fn=<MeanBackward1>) 34900
tensor(148.0789, grad_fn=<MeanBackward1>) 35000
tensor(150.3133, grad_fn=<MeanBackward1>) 35100
tensor(151.0993, grad_fn=<MeanBackward1>) 35200
tensor(146.2246, grad_fn=<MeanBackward1>) 35300
tensor(153.7743, grad_fn=<MeanBackward1>) 35400
tensor(144.2014, grad_fn=<MeanBackward1>) 35500
tensor(146.3279, grad_fn=<MeanBackward1>) 35600
tensor(150.0573, grad_fn=<MeanBackward1>) 35700
tensor(147.2516, grad_fn=<MeanBackward1>) 35800
tensor(144.0405, grad_fn=<MeanBackward1>) 35900
tensor(141.4045, grad_fn=<MeanBackward1>) 36000
tensor(146.5808, grad_fn=<MeanBackward1>) 36100
tensor(146.4994, grad_fn=<MeanBackward1>) 36200
tensor(142.8770, grad_fn=<MeanBackward1>) 36300
tensor(144.8176, grad_fn=<MeanBackward1>

tensor(97.3478, grad_fn=<MeanBackward1>) 51800
tensor(95.1120, grad_fn=<MeanBackward1>) 51900
tensor(93.8167, grad_fn=<MeanBackward1>) 52000
tensor(99.5109, grad_fn=<MeanBackward1>) 52100
tensor(96.1338, grad_fn=<MeanBackward1>) 52200
tensor(94.1536, grad_fn=<MeanBackward1>) 52300
tensor(97.1968, grad_fn=<MeanBackward1>) 52400
tensor(95.5906, grad_fn=<MeanBackward1>) 52500
tensor(94.7018, grad_fn=<MeanBackward1>) 52600
tensor(97.6129, grad_fn=<MeanBackward1>) 52700
tensor(96.4438, grad_fn=<MeanBackward1>) 52800
tensor(96.0964, grad_fn=<MeanBackward1>) 52900
tensor(95.4565, grad_fn=<MeanBackward1>) 53000
tensor(96.5955, grad_fn=<MeanBackward1>) 53100
tensor(93.1324, grad_fn=<MeanBackward1>) 53200
tensor(94.9028, grad_fn=<MeanBackward1>) 53300
tensor(95.1940, grad_fn=<MeanBackward1>) 53400
tensor(95.8240, grad_fn=<MeanBackward1>) 53500
tensor(96.1851, grad_fn=<MeanBackward1>) 53600
tensor(94.8066, grad_fn=<MeanBackward1>) 53700
tensor(97.4835, grad_fn=<MeanBackward1>) 53800
tensor(97.985

tensor(129.6468, grad_fn=<MeanBackward1>) 69000
tensor(130.9536, grad_fn=<MeanBackward1>) 69100
tensor(132.2050, grad_fn=<MeanBackward1>) 69200
tensor(125.4845, grad_fn=<MeanBackward1>) 69300
tensor(127.5276, grad_fn=<MeanBackward1>) 69400
tensor(133.4069, grad_fn=<MeanBackward1>) 69500
tensor(128.1409, grad_fn=<MeanBackward1>) 69600
tensor(130.9486, grad_fn=<MeanBackward1>) 69700
tensor(129.4337, grad_fn=<MeanBackward1>) 69800
tensor(129.4167, grad_fn=<MeanBackward1>) 69900
tensor(126.6541, grad_fn=<MeanBackward1>) 70000
tensor(129.5190, grad_fn=<MeanBackward1>) 70100
tensor(126.3294, grad_fn=<MeanBackward1>) 70200
tensor(125.7597, grad_fn=<MeanBackward1>) 70300
tensor(124.9708, grad_fn=<MeanBackward1>) 70400
tensor(127.7591, grad_fn=<MeanBackward1>) 70500
tensor(126.7276, grad_fn=<MeanBackward1>) 70600
tensor(127.8017, grad_fn=<MeanBackward1>) 70700
tensor(123.9154, grad_fn=<MeanBackward1>) 70800
tensor(128.6561, grad_fn=<MeanBackward1>) 70900
tensor(127.2157, grad_fn=<MeanBackward1>

tensor(4.6276e+24, grad_fn=<MeanBackward1>) 85900
tensor(4.4444e+25, grad_fn=<MeanBackward1>) 86000
tensor(8.1156e+26, grad_fn=<MeanBackward1>) 86100
tensor(5.7200e+26, grad_fn=<MeanBackward1>) 86200
tensor(5.6219e+22, grad_fn=<MeanBackward1>) 86300
tensor(3.0163e+20, grad_fn=<MeanBackward1>) 86400
tensor(9.5646e+25, grad_fn=<MeanBackward1>) 86500
tensor(8.7808e+26, grad_fn=<MeanBackward1>) 86600
tensor(2.4252e+26, grad_fn=<MeanBackward1>) 86700
tensor(1.3752e+25, grad_fn=<MeanBackward1>) 86800
tensor(3.4258e+24, grad_fn=<MeanBackward1>) 86900
tensor(4.6189e+24, grad_fn=<MeanBackward1>) 87000
tensor(4.4419e+25, grad_fn=<MeanBackward1>) 87100
tensor(1.3381e+25, grad_fn=<MeanBackward1>) 87200
tensor(2.3037e+24, grad_fn=<MeanBackward1>) 87300
tensor(2.2429e+24, grad_fn=<MeanBackward1>) 87400
tensor(1.6087e+25, grad_fn=<MeanBackward1>) 87500
tensor(1.9650e+23, grad_fn=<MeanBackward1>) 87600
tensor(4.9525e+23, grad_fn=<MeanBackward1>) 87700
tensor(1.9354e+24, grad_fn=<MeanBackward1>) 87800


SDE simulation functions

In [8]:
# This is the solving the VE process exactly given deterministic initial conditions
def ve_dynamics(init, sigma_min, sigma_max, T=1.0):
    init = init + sigma_max * torch.randn_like(init)
    return init
    
def reverse_sde(score, init, sigma_min, sigma_max, T=1.0, lr=0.01):
    step = int(T/lr)
    for i in range(step,-1,-1):
        evalpoint = torch.cat(((torch.tensor(lr*i/T)).repeat(init.shape[0],1),init),1)
        drift = torch.zeros_like(evalpoint)
        sigma = sigma_min *(sigma_max/sigma_min)**(i*lr/T)
        diffusion = sigma * torch.sqrt(2*(torch.log(sigma_max)-torch.log(sigma_min)))
        
        init = init + lr  * diffusion**2 * score(evalpoint).detach() 
        init = init + diffusion * torch.randn_like(init) * np.sqrt(lr)
    return init



# The following is the deterministic ODE flow that can also sample from the target distribution

def reverse_ode_flow(score,init, sigma_min, sigma_max, T=1.0, lr = 0.01):
    step = int(T/lr)
    for i in range(step,-1,-1):
        evalpoint = torch.cat(((torch.tensor(lr*i/T)).repeat(init.shape[0],1),init),1)
        sigma = sigma_min *(sigma_max/sigma_min)**(i*lr/T)
        diffusion = sigma * torch.sqrt(2*(torch.log(sigma_max)-torch.log(sigma_min)))
        
        init = init + lr/2* diffusion**2 * score(evalpoint).detach() 
    return init

Sample using the score network 

In [9]:
# Denoising the normal distribution 
samples_lang = torch.randn(10000, d+d_orth)*sigma_max # * (right_bound - left_bound) + left_bound
samples_lang = reverse_sde(scorenet, samples_lang, torch.tensor(sigma_min), torch.tensor(sigma_max)).detach().numpy()


p_samples = toy_data.inf_train_gen(dataset, batch_size = 10000)
ve_samples = ve_dynamics(torch.tensor(p_samples).to(dtype = torch.float32),  torch.tensor(sigma_min), torch.tensor(sigma_max))


In [10]:
plt.clf()
p_samples = toy_data.inf_train_gen(dataset, batch_size = 10000, misc_params={})
ve_samples = ve_dynamics(torch.tensor(p_samples).to(dtype = torch.float32),  torch.tensor(sigma_min), torch.tensor(sigma_max))
plt.scatter(ve_samples[:,0],ve_samples[:,1],s = 0.1)
plt.axis('square')
plt.title('Samples from VE process')
plt.show()


plt.clf()
p_samples = toy_data.inf_train_gen(dataset, batch_size = 10000, misc_params={})
samples_true = torch.tensor(p_samples).to(dtype = torch.float32)
plt.scatter(samples_true[:,0],samples_true[:,1],s = 0.1)
plt.axis('square')
plt.title('True samples')
plt.show()


plt.clf()
plt.scatter(samples_lang[:,0],samples_lang[:,1],s = 0.1)
plt.axis('square')
plt.title('Samples from reverse SDE')
plt.show()




NameError: name 'df' is not defined

<Figure size 640x480 with 0 Axes>

In [11]:
filename = f"../../assets/{dataset}/ve_sgm_heavytail_submanifold_{N_samples}samples.pickle"
with open(filename,"wb") as fw:
    pickle.dump([ve_samples, samples_lang] , fw)