# Score-based Generative Modeling with SDEs (Simple examples)

In [1]:
import os

'/Users/hyemin/Documents/source_code/Proximal_generative_models/scripts/sgm_simple-benjamin'

In [2]:
import matplotlib.pyplot as plt
import torch
import torch.optim as optim
import torch.nn as nn
import lib.toy_data as toy_data
import numpy as np
import pickle

Basic parameters

In [3]:
learning_rate = 1e-3 # learning rate for training neural network
batch_size = 7160  # batch size during training of neural network; NOTE(review): the training loop in this notebook does not reference it and trains on all samples at once -- confirm it is still needed
N_samples = 7160  # number of data points requested from toy_data; also embedded in the output file name
epochs = 100000   # Number of training iterations for the neural network (one optimizer step each)
sigma_max = 500.0  # largest noise standard deviation of the geometric noise schedule (reached at the final time)
sigma_min = 0.01  # smallest noise standard deviation of the geometric noise schedule (reached at the initial time)
T = 1    # Forward simulation time in the forward SDE (fixed)
dataset = 'Keystrokes' # Dataset choice, see toy_data for full options of toy datasets 
d = 1  # data dimension: the score network takes d+1 inputs (time and data) and returns d outputs

We first initialize the neural net that models the score function. 

In [4]:
## Model construction

class DenoisingModel(nn.Module):
    """Fully connected score network acting on a concatenated (time, data) vector.

    The network is a chain of nine linear layers, ``fc1`` ... ``fc9``, with a
    GELU (``activation1`` ... ``activation8``) after each of the first eight.
    Layer widths follow the pattern

        (dim+1 -> h -> h -> dim+1) twice, then (dim+1 -> h -> h -> dim)

    where ``h = hidden_units``.  Every weight matrix is re-initialised with
    Xavier-uniform; biases keep the ``nn.Linear`` default initialisation.

    Args:
        hidden_units: Width ``h`` of the hidden layers (cast to ``int``).
        dim: Dimension of the data.  When ``None`` (the default) the
            notebook-level variable ``d`` is used, so ``DenoisingModel()``
            behaves as before while the class can also be built without
            relying on that global.
    """

    def __init__(self, hidden_units=32, dim=None):
        super().__init__()

        # Fall back to the notebook global only when no dimension is passed.
        if dim is None:
            dim = d
        width = int(hidden_units)
        joint = dim + 1  # time coordinate plus data coordinates

        layer_sizes = (
            [(joint, width), (width, width), (width, joint)] * 2
            + [(joint, width), (width, width), (width, dim)]
        )

        # Register modules in the order fc1, activation1, fc2, ... fc9 so that
        # attribute names, printed structure and parameter order stay the same.
        for idx, (n_in, n_out) in enumerate(layer_sizes, start=1):
            setattr(self, f"fc{idx}", self._xavier_linear(n_in, n_out))
            if idx < len(layer_sizes):
                setattr(self, f"activation{idx}", nn.GELU())

    @staticmethod
    def _xavier_linear(in_features, out_features):
        """Build a biased linear layer whose weight is Xavier-uniform initialised."""
        layer = nn.Linear(in_features, out_features, bias=True)
        nn.init.xavier_uniform_(layer.weight)
        return layer

    def forward(self, x):
        """Evaluate the network.

        Args:
            x: Tensor whose last dimension has size ``dim + 1``.

        Returns:
            Tensor with the same leading dimensions and last dimension ``dim``.
        """
        # fc1..fc8 are each followed by their GELU; fc9 is the linear output layer.
        for idx in range(1, 9):
            x = getattr(self, f"fc{idx}")(x)
            x = getattr(self, f"activation{idx}")(x)
        return self.fc9(x)
      
# Instantiate the score network with its default hidden width (hidden_units=32)
scorenet = DenoisingModel()
print(scorenet)
# Adam optimizer over all parameters of the score network, using the configured learning rate
optimizer = optim.Adam(scorenet.parameters(), lr=learning_rate)

DenoisingModel(
  (fc1): Linear(in_features=2, out_features=32, bias=True)
  (activation1): GELU(approximate='none')
  (fc2): Linear(in_features=32, out_features=32, bias=True)
  (activation2): GELU(approximate='none')
  (fc3): Linear(in_features=32, out_features=2, bias=True)
  (activation3): GELU(approximate='none')
  (fc4): Linear(in_features=2, out_features=32, bias=True)
  (activation4): GELU(approximate='none')
  (fc5): Linear(in_features=32, out_features=32, bias=True)
  (activation5): GELU(approximate='none')
  (fc6): Linear(in_features=32, out_features=2, bias=True)
  (activation6): GELU(approximate='none')
  (fc7): Linear(in_features=2, out_features=32, bias=True)
  (activation7): GELU(approximate='none')
  (fc8): Linear(in_features=32, out_features=32, bias=True)
  (activation8): GELU(approximate='none')
  (fc9): Linear(in_features=32, out_features=1, bias=True)
)


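Define loss functions. `time_dsm_score_estimator` perturbs the data with the variance-exploding (VE) noise schedule $\sigma(t) = \sigma_{\min}\,(\sigma_{\max}/\sigma_{\min})^{(t - T_{\min})/(T_{\max} - T_{\min})}$ and minimizes $\tfrac{1}{2}\,\lVert \sigma(t)\, s_\theta(x_t, t) + z \rVert^2$. `deterministic_time_dsm_score_estimator` assumes that the forward process is a standard OU process $dx = -\tfrac{x}{2}\,dt + dW$; there the choice of $\lambda(t)$ in the SGM objective function is equal to 1 (the constant in front of the $dW$ term).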

In [5]:
# Loss function -- we use the denoising diffusions objective function
# Scorenet is the score model, samples are the training samples, Tmin/Tmax are the time interval that is being trained on, and eps is so that Tmin is not sampled. 

def time_dsm_score_estimator(scorenet,samples,sigma_min,sigma_max,Tmin,Tmax,eps):
    """Denoising score-matching loss with uniformly sampled times.

    One time ``t`` is drawn per sample, uniformly on ``[Tmin + eps, Tmax)``.
    Each sample is perturbed with Gaussian noise of standard deviation
    ``sigma(t) = sigma_min * (sigma_max / sigma_min) ** ((t - Tmin) / (Tmax - Tmin))``
    and the network is evaluated on ``[t, perturbed sample]``.

    Args:
        scorenet: Callable mapping the concatenated ``(t, x)`` batch to scores.
        samples: Training batch; it is concatenated with a time column along
            dimension 1.
        sigma_min, sigma_max: End points of the geometric noise schedule.
        Tmin, Tmax: Time interval being trained on.
        eps: Offset that keeps ``Tmin`` itself from being sampled.

    Returns:
        Scalar tensor: batch mean of ``0.5 * ||sigma(t) * score + z||^2``.
    """
    n_batch = samples.shape[0]

    # One uniformly distributed time per sample
    t = torch.rand(n_batch) * (Tmax - Tmin - eps) + eps + Tmin

    # Standard normal draw, scaled by the per-sample noise level
    z = torch.randn_like(samples)
    schedule_pos = (t - Tmin) / (Tmax - Tmin)
    sigmas = sigma_min * (sigma_max / sigma_min) ** schedule_pos
    # Trailing singleton axes let sigmas broadcast against the samples
    sigmas = sigmas.view(n_batch, *([1] * (samples.dim() - 1)))
    perturbed_samples = samples + z * sigmas

    # The network sees the time as the first column, followed by the noisy data
    net_input = torch.cat((t.reshape(-1, 1), perturbed_samples), 1)
    scores = scorenet(net_input)
    scores = scores.view(scores.shape[0], -1)

    per_sample_loss = 0.5 * ((scores * sigmas + z) ** 2).sum(dim=-1)
    return per_sample_loss.mean(dim=0)


# Loss function
# Use this variant when the network should be trained on a specific, user-supplied mesh of times.
def deterministic_time_dsm_score_estimator(scorenet,samples,t):
    """Denoising score-matching loss accumulated over a fixed time mesh.

    For every mesh point except the last, the samples are perturbed as
    ``samples * exp(-t/2) + sqrt(1 - exp(-t)) * z`` and the squared error
    between the network output and ``-noise / sigma^2`` is averaged over the
    batch.  Each term is weighted by the length of the mesh interval that
    starts at that point.

    Args:
        scorenet: Callable mapping the concatenated ``(t, x)`` batch to scores.
        samples: Training batch; it is concatenated with a time column along
            dimension 1.
        t: Indexable sequence of time tensors defining the mesh.

    Returns:
        The accumulated loss, or the integer ``0`` if the mesh has fewer than
        two points.
    """
    total = 0
    for k in range(len(t) - 1):
        t_now, t_next = t[k], t[k + 1]

        # Perturb the samples with the noise level belonging to t_now
        std = torch.sqrt(1 - torch.exp(-t_now))
        noise = torch.randn_like(samples) * std
        perturbed_samples = samples * torch.exp(-0.5 * t_now) + noise

        # Network prediction on [time, perturbed sample]
        time_column = t_now.repeat(perturbed_samples.shape[0], 1)
        scores = scorenet(torch.cat((time_column, perturbed_samples), 1))
        scores = scores.view(scores.shape[0], -1)

        # Regression target for the perturbed samples
        target = - 1/ (std ** 2) * (noise)
        target = target.view(target.shape[0], -1)

        # Batch-averaged squared error, weighted by the interval length
        sq_err = 0.5 * ((scores - target) ** 2).sum(dim=-1)
        total = total + (t_next - t_now) * sq_err.mean(dim=0)

    return total


Training the score network

In [6]:
# Training the score network
# sample toy_data
#p_samples_all = toy_data.inf_train_gen(dataset, N_samples)
p_samples = toy_data.inf_train_gen(dataset, N_samples)

# The training data never changes, so build the float32 tensor once instead of
# re-creating it on every iteration.
samples = torch.tensor(p_samples, dtype=torch.float32)

# Full-batch training: every step evaluates the loss on all samples, with
# freshly drawn times and noise inside the loss function.
for step in range(epochs):
    # evaluate loss function and gradient
    loss = time_dsm_score_estimator(scorenet, samples, sigma_min, sigma_max, 0, T, eps=0.0005)
    optimizer.zero_grad()
    loss.backward()

    # Update score network
    optimizer.step()

    # Report progress every 100 steps; .item() prints the plain number rather
    # than the tensor repr with its grad_fn.
    if step % 100 == 0:
        print(loss.item(), step)




tensor(29254.3242, grad_fn=<MeanBackward1>) 0
tensor(1.1315, grad_fn=<MeanBackward1>) 100
tensor(0.5438, grad_fn=<MeanBackward1>) 200
tensor(0.5124, grad_fn=<MeanBackward1>) 300
tensor(0.4768, grad_fn=<MeanBackward1>) 400
tensor(0.9859, grad_fn=<MeanBackward1>) 500
tensor(0.4715, grad_fn=<MeanBackward1>) 600
tensor(0.4680, grad_fn=<MeanBackward1>) 700
tensor(0.4754, grad_fn=<MeanBackward1>) 800
tensor(0.4758, grad_fn=<MeanBackward1>) 900
tensor(0.4635, grad_fn=<MeanBackward1>) 1000
tensor(0.4663, grad_fn=<MeanBackward1>) 1100
tensor(0.4801, grad_fn=<MeanBackward1>) 1200
tensor(0.4829, grad_fn=<MeanBackward1>) 1300
tensor(0.4697, grad_fn=<MeanBackward1>) 1400
tensor(0.4733, grad_fn=<MeanBackward1>) 1500
tensor(0.4775, grad_fn=<MeanBackward1>) 1600
tensor(0.4632, grad_fn=<MeanBackward1>) 1700
tensor(0.4745, grad_fn=<MeanBackward1>) 1800
tensor(0.4643, grad_fn=<MeanBackward1>) 1900
tensor(0.4657, grad_fn=<MeanBackward1>) 2000
tensor(0.4571, grad_fn=<MeanBackward1>) 2100
tensor(0.4557, gra

tensor(0.4740, grad_fn=<MeanBackward1>) 18100
tensor(0.4529, grad_fn=<MeanBackward1>) 18200
tensor(0.4612, grad_fn=<MeanBackward1>) 18300
tensor(0.4753, grad_fn=<MeanBackward1>) 18400
tensor(0.4615, grad_fn=<MeanBackward1>) 18500
tensor(0.4706, grad_fn=<MeanBackward1>) 18600
tensor(0.4589, grad_fn=<MeanBackward1>) 18700
tensor(0.4624, grad_fn=<MeanBackward1>) 18800
tensor(0.4543, grad_fn=<MeanBackward1>) 18900
tensor(0.4609, grad_fn=<MeanBackward1>) 19000
tensor(0.4596, grad_fn=<MeanBackward1>) 19100
tensor(0.4644, grad_fn=<MeanBackward1>) 19200
tensor(0.4860, grad_fn=<MeanBackward1>) 19300
tensor(0.4649, grad_fn=<MeanBackward1>) 19400
tensor(0.4658, grad_fn=<MeanBackward1>) 19500
tensor(0.4673, grad_fn=<MeanBackward1>) 19600
tensor(0.4628, grad_fn=<MeanBackward1>) 19700
tensor(0.4706, grad_fn=<MeanBackward1>) 19800
tensor(0.4687, grad_fn=<MeanBackward1>) 19900
tensor(0.4714, grad_fn=<MeanBackward1>) 20000
tensor(0.4535, grad_fn=<MeanBackward1>) 20100
tensor(0.4766, grad_fn=<MeanBackwa

tensor(0.4725, grad_fn=<MeanBackward1>) 36000
tensor(0.4620, grad_fn=<MeanBackward1>) 36100
tensor(0.4564, grad_fn=<MeanBackward1>) 36200
tensor(0.4834, grad_fn=<MeanBackward1>) 36300
tensor(0.4584, grad_fn=<MeanBackward1>) 36400
tensor(0.4542, grad_fn=<MeanBackward1>) 36500
tensor(0.4570, grad_fn=<MeanBackward1>) 36600
tensor(0.4606, grad_fn=<MeanBackward1>) 36700
tensor(0.4631, grad_fn=<MeanBackward1>) 36800
tensor(0.4636, grad_fn=<MeanBackward1>) 36900
tensor(0.4902, grad_fn=<MeanBackward1>) 37000
tensor(0.4670, grad_fn=<MeanBackward1>) 37100
tensor(0.4565, grad_fn=<MeanBackward1>) 37200
tensor(0.4644, grad_fn=<MeanBackward1>) 37300
tensor(0.4573, grad_fn=<MeanBackward1>) 37400
tensor(0.4829, grad_fn=<MeanBackward1>) 37500
tensor(0.4557, grad_fn=<MeanBackward1>) 37600
tensor(0.4589, grad_fn=<MeanBackward1>) 37700
tensor(0.4554, grad_fn=<MeanBackward1>) 37800
tensor(0.4824, grad_fn=<MeanBackward1>) 37900
tensor(0.4994, grad_fn=<MeanBackward1>) 38000
tensor(0.5117, grad_fn=<MeanBackwa

tensor(0.4990, grad_fn=<MeanBackward1>) 53900
tensor(0.4956, grad_fn=<MeanBackward1>) 54000
tensor(0.5105, grad_fn=<MeanBackward1>) 54100
tensor(0.4998, grad_fn=<MeanBackward1>) 54200
tensor(0.4991, grad_fn=<MeanBackward1>) 54300
tensor(0.5075, grad_fn=<MeanBackward1>) 54400
tensor(0.5073, grad_fn=<MeanBackward1>) 54500
tensor(0.5051, grad_fn=<MeanBackward1>) 54600
tensor(0.5172, grad_fn=<MeanBackward1>) 54700
tensor(0.5113, grad_fn=<MeanBackward1>) 54800
tensor(0.4988, grad_fn=<MeanBackward1>) 54900
tensor(0.4952, grad_fn=<MeanBackward1>) 55000
tensor(0.5093, grad_fn=<MeanBackward1>) 55100
tensor(0.4981, grad_fn=<MeanBackward1>) 55200
tensor(0.5055, grad_fn=<MeanBackward1>) 55300
tensor(0.5087, grad_fn=<MeanBackward1>) 55400
tensor(0.4964, grad_fn=<MeanBackward1>) 55500
tensor(0.4824, grad_fn=<MeanBackward1>) 55600
tensor(0.5106, grad_fn=<MeanBackward1>) 55700
tensor(0.5045, grad_fn=<MeanBackward1>) 55800
tensor(0.4890, grad_fn=<MeanBackward1>) 55900
tensor(0.5088, grad_fn=<MeanBackwa

tensor(0.5199, grad_fn=<MeanBackward1>) 71800
tensor(0.5072, grad_fn=<MeanBackward1>) 71900
tensor(0.5053, grad_fn=<MeanBackward1>) 72000
tensor(0.5128, grad_fn=<MeanBackward1>) 72100
tensor(0.4975, grad_fn=<MeanBackward1>) 72200
tensor(0.4951, grad_fn=<MeanBackward1>) 72300
tensor(0.4980, grad_fn=<MeanBackward1>) 72400
tensor(0.5056, grad_fn=<MeanBackward1>) 72500
tensor(0.5019, grad_fn=<MeanBackward1>) 72600
tensor(0.5059, grad_fn=<MeanBackward1>) 72700
tensor(0.4988, grad_fn=<MeanBackward1>) 72800
tensor(0.4896, grad_fn=<MeanBackward1>) 72900
tensor(0.5008, grad_fn=<MeanBackward1>) 73000
tensor(0.5009, grad_fn=<MeanBackward1>) 73100
tensor(0.4903, grad_fn=<MeanBackward1>) 73200
tensor(0.5057, grad_fn=<MeanBackward1>) 73300
tensor(0.5016, grad_fn=<MeanBackward1>) 73400
tensor(0.5097, grad_fn=<MeanBackward1>) 73500
tensor(0.4910, grad_fn=<MeanBackward1>) 73600
tensor(0.5063, grad_fn=<MeanBackward1>) 73700
tensor(0.5078, grad_fn=<MeanBackward1>) 73800
tensor(0.4957, grad_fn=<MeanBackwa

tensor(0.4882, grad_fn=<MeanBackward1>) 89700
tensor(0.5121, grad_fn=<MeanBackward1>) 89800
tensor(0.4973, grad_fn=<MeanBackward1>) 89900
tensor(0.4972, grad_fn=<MeanBackward1>) 90000
tensor(0.5082, grad_fn=<MeanBackward1>) 90100
tensor(0.5003, grad_fn=<MeanBackward1>) 90200
tensor(0.5121, grad_fn=<MeanBackward1>) 90300
tensor(0.5087, grad_fn=<MeanBackward1>) 90400
tensor(0.4919, grad_fn=<MeanBackward1>) 90500
tensor(0.4831, grad_fn=<MeanBackward1>) 90600
tensor(0.4957, grad_fn=<MeanBackward1>) 90700
tensor(0.4883, grad_fn=<MeanBackward1>) 90800
tensor(0.5213, grad_fn=<MeanBackward1>) 90900
tensor(0.5005, grad_fn=<MeanBackward1>) 91000
tensor(0.5054, grad_fn=<MeanBackward1>) 91100
tensor(0.5067, grad_fn=<MeanBackward1>) 91200
tensor(0.4703, grad_fn=<MeanBackward1>) 91300
tensor(0.4382, grad_fn=<MeanBackward1>) 91400
tensor(0.4545, grad_fn=<MeanBackward1>) 91500
tensor(0.4708, grad_fn=<MeanBackward1>) 91600
tensor(0.4724, grad_fn=<MeanBackward1>) 91700
tensor(0.4654, grad_fn=<MeanBackwa

SDE simulation functions

In [7]:
# Solve the VE forward process exactly, given deterministic initial conditions
def ve_dynamics(init, sigma_min, sigma_max, T=1.0):
    """Add Gaussian noise of standard deviation ``sigma_max`` to ``init``.

    Args:
        init: Tensor of starting points; it is not modified.
        sigma_min: Accepted for signature compatibility; not used.
        sigma_max: Standard deviation of the added noise.
        T: Accepted for signature compatibility; not used.

    Returns:
        A new tensor ``init + sigma_max * z`` with ``z`` standard normal.
    """
    gaussian_noise = torch.randn_like(init)
    return init + sigma_max * gaussian_noise
    
def reverse_sde(score, init, sigma_min, sigma_max, T=1.0, lr=0.01):
    """Sample by Euler-Maruyama integration of the reverse-time VE SDE.

    With ``sigma(t) = sigma_min * (sigma_max / sigma_min) ** (t / T)`` and
    ``g(t) = sigma(t) * sqrt(2 * (log(sigma_max) - log(sigma_min)))``, each
    step applies

        x <- x + lr * g(t)^2 * score([t/T, x]) + g(t) * sqrt(lr) * z

    for ``t = T, T - lr, ..., lr``.  That is ``int(T / lr)`` steps, so the
    iterate ends at time 0.  (The loop previously also ran at ``t = 0``,
    taking one extra step past the start of the time interval.)

    Args:
        score: Callable mapping a ``(batch, 1 + dim)`` tensor ``[time, x]`` to
            the estimated score, shape ``(batch, dim)``.
        init: ``(batch, dim)`` tensor of starting points at the final time.
        sigma_min, sigma_max: End points of the noise schedule.  Plain numbers
            and tensors are both accepted.
        T: Final time of the forward process.
        lr: Step size of the discretisation.

    Returns:
        Tensor of the same shape as ``init`` holding the generated samples.
    """
    # Accept Python floats as well as tensors: torch.log needs tensor inputs.
    sigma_min = torch.as_tensor(sigma_min, dtype=init.dtype, device=init.device)
    sigma_max = torch.as_tensor(sigma_max, dtype=init.dtype, device=init.device)

    # Loop-invariant factor of the diffusion coefficient
    log_term = torch.sqrt(2*(torch.log(sigma_max)-torch.log(sigma_min)))

    step = int(T/lr)
    for i in range(step, 0, -1):
        time_column = torch.as_tensor(lr*i/T, dtype=init.dtype, device=init.device).repeat(init.shape[0], 1)
        evalpoint = torch.cat((time_column, init), 1)
        sigma = sigma_min *(sigma_max/sigma_min)**(i*lr/T)
        diffusion = sigma * log_term

        # No gradients are needed while sampling, so skip building the graph.
        with torch.no_grad():
            score_value = score(evalpoint)

        init = init + lr  * diffusion**2 * score_value
        init = init + diffusion * torch.randn_like(init) * np.sqrt(lr)
    return init



# The following is the deterministic ODE flow that can also sample from the target distribution

def reverse_ode_flow(score,init, sigma_min, sigma_max, T=1.0, lr = 0.01):
    """Sample by explicit Euler integration of the deterministic reverse flow.

    With ``sigma(t) = sigma_min * (sigma_max / sigma_min) ** (t / T)`` and
    ``g(t) = sigma(t) * sqrt(2 * (log(sigma_max) - log(sigma_min)))``, each
    step applies

        x <- x + (lr / 2) * g(t)^2 * score([t/T, x])

    for ``t = T, T - lr, ..., lr``.  That is ``int(T / lr)`` steps, so the
    iterate ends at time 0.  (The loop previously also ran at ``t = 0``,
    taking one extra step past the start of the time interval.)

    Args:
        score: Callable mapping a ``(batch, 1 + dim)`` tensor ``[time, x]`` to
            the estimated score, shape ``(batch, dim)``.
        init: ``(batch, dim)`` tensor of starting points at the final time.
        sigma_min, sigma_max: End points of the noise schedule.  Plain numbers
            and tensors are both accepted.
        T: Final time of the forward process.
        lr: Step size of the discretisation.

    Returns:
        Tensor of the same shape as ``init`` holding the generated samples.
    """
    # Accept Python floats as well as tensors: torch.log needs tensor inputs.
    sigma_min = torch.as_tensor(sigma_min, dtype=init.dtype, device=init.device)
    sigma_max = torch.as_tensor(sigma_max, dtype=init.dtype, device=init.device)

    # Loop-invariant factor of the diffusion coefficient
    log_term = torch.sqrt(2*(torch.log(sigma_max)-torch.log(sigma_min)))

    step = int(T/lr)
    for i in range(step, 0, -1):
        time_column = torch.as_tensor(lr*i/T, dtype=init.dtype, device=init.device).repeat(init.shape[0], 1)
        evalpoint = torch.cat((time_column, init), 1)
        sigma = sigma_min *(sigma_max/sigma_min)**(i*lr/T)
        diffusion = sigma * log_term

        # No gradients are needed while sampling, so skip building the graph.
        with torch.no_grad():
            score_value = score(evalpoint)

        init = init + lr/2* diffusion**2 * score_value
    return init

Sample using the score network 

In [8]:
# Start from 10000 one-dimensional Gaussian draws with standard deviation sigma_max
# and run them through the reverse SDE driven by the trained score network.
samples_lang = reverse_sde(
    scorenet,
    torch.randn(10000, 1) * sigma_max,
    torch.tensor(sigma_min),
    torch.tensor(sigma_max),
).detach().numpy()


In [9]:
#plt.clf()
# Request N_samples data points; the config constant replaces the hard-coded 7160
# so this stays in step with the sample count used in the output file name.
p_samples = toy_data.inf_train_gen(dataset, batch_size=N_samples)

# Perturb the data with the VE forward process at its largest noise level
ve_samples = ve_dynamics(
    torch.tensor(p_samples).to(dtype=torch.float32),
    torch.tensor(sigma_min),
    torch.tensor(sigma_max),
)


"\nbinsize=500\nplt.hist(ve_samples, label='generated', alpha=0.9,cumulative=-1, density=True, bins=binsize)\nplt.hist(p_samples, label='true',alpha=0.5,cumulative=-1, density=True, bins=binsize)\nplt.axis('square')\nplt.title('Samples from VE process')\nplt.legend()\nplt.show()\n"

In [10]:
filename = f"../../assets/{dataset}/ve_sgm_{N_samples}samples.pickle"

# Make sure the dataset-specific output folder exists, so the save cannot fail
# on a missing directory after the long training run.
os.makedirs(os.path.dirname(filename), exist_ok=True)

# Stored as a two-element list: [forward VE samples, reverse-SDE samples]
with open(filename, "wb") as fw:
    pickle.dump([ve_samples, samples_lang], fw)