In [None]:
import torch
from math import comb, pow, log2
import torch.optim as optim
from torch.distributions import Binomial

from datetime import datetime
import os
import json
from matplotlib import pyplot as plt
from matplotlib.animation import FuncAnimation

In [None]:
def get_next():
    """Return the next training minibatch, or None once the epoch is exhausted.

    Reads the notebook-level globals ``data``, ``num_data``, ``data_batch_size`` and
    ``index_permutation`` and advances the global cursor ``cur_ind``. The last batch
    of an epoch may hold fewer than ``data_batch_size`` rows. When the cursor has
    passed the end of the data it is reset to 0, ``index_permutation`` is redrawn so
    the next epoch visits the rows in a fresh order, and None is returned to signal
    the end of the epoch.
    """
    # Only names that are re-bound need `global`; the others are read-only here.
    global index_permutation
    global cur_ind

    if cur_ind < num_data:
        indices = index_permutation[cur_ind:cur_ind + data_batch_size]
        cur_ind += data_batch_size
        return data[indices]

    # End of epoch: rewind and reshuffle so successive epochs see different minibatches.
    cur_ind = 0
    index_permutation = torch.randperm(num_data)
    return None


def calculate_beta_tilde_t(beta_tilde_t, cur_t, beta_t_fn=None):
    """One step of the cumulative-noise recurrence.

    Returns
        beta_tilde[cur_t] = beta_tilde[cur_t - 1] + beta_t(cur_t)
                            - beta_t(cur_t) * beta_tilde[cur_t - 1]

    Args:
        beta_tilde_t: indexable sequence (list or tensor) of previous values; only
            index ``cur_t - 1`` is read.
        cur_t: current timestep (>= 1).
        beta_t_fn: callable ``t -> beta_t``. When omitted, a module-level ``beta_t``
            function is looked up; none is defined in the visible notebook code
            (``beta_t`` only exists as a ``DiffusionModel`` method), so pass e.g.
            ``diffusion_model.beta_t`` explicitly.

    Raises:
        NameError: if ``beta_t_fn`` is None and no global ``beta_t`` exists.
    """
    if beta_t_fn is None:
        beta_t_fn = globals().get('beta_t')
        if beta_t_fn is None:
            raise NameError("no 'beta_t' function available; pass beta_t_fn explicitly")
    b = beta_t_fn(cur_t)
    prev = beta_tilde_t[cur_t - 1]
    return prev + b - b * prev

def _load_sequence_set(path):
    """Read a data file (one sequence per line, one digit per character) into a set of float tuples."""
    with open(path, 'r') as f:
        # Blank lines are skipped: they are not sequences (and used to crash torch.vstack).
        return {tuple(float(ch) for ch in line.strip()) for line in f if line.strip()}


def validate(model, num_samples, batch_size, train_path='train.txt', val_path='val.txt'):
    """Count how many model samples exactly reproduce a training or validation sequence.

    Draws samples in batches of ``batch_size`` via ``model.p_sample(batch_size)`` until
    ``num_samples`` have been classified. Training membership is checked first, so a
    sample present in both files counts as 'train'.

    Args:
        model: object with a ``p_sample(batch_size)`` method returning a tensor of
            sequences (one per row).
        num_samples: total number of samples to classify.
        batch_size: number of samples requested per ``p_sample`` call.
        train_path, val_path: data files, one sequence per line, one digit per character.

    Returns:
        (train_count, val_count, other_count), summing to ``max(num_samples, 0)``.

    Raises:
        ValueError: if ``batch_size`` is not positive (the loop could never finish).
        RuntimeError: if ``model.p_sample`` returns an empty batch.
    """
    if batch_size <= 0:
        raise ValueError('batch_size must be positive, got {}'.format(batch_size))

    # Sets of tuples give O(1) membership tests instead of scanning a list per sample.
    train_set = _load_sequence_set(train_path)
    val_set = _load_sequence_set(val_path)

    train_count = 0
    val_count = 0
    other_count = 0

    sample_count = 0
    while sample_count < num_samples:
        samples = model.p_sample(batch_size).detach().cpu().tolist()
        if not samples:
            raise RuntimeError('model.p_sample returned an empty batch')

        for sample in samples:
            key = tuple(sample)
            if key in train_set:
                train_count += 1
            elif key in val_set:
                val_count += 1
            else:
                other_count += 1
            sample_count += 1
            if sample_count >= num_samples:
                break

    return train_count, val_count, other_count

def plot_evolution(epochs, examples, step_name='Step', filename='evolution.gif',
                   writer='imagemagick', fps=2):
    """Animate a sequence of 2-D sample batches and save the animation to ``filename``.

    Useful for showing how a fixed set of samples changes during training, or for
    illustrating the reverse process.

    Args:
        epochs: frame labels; ``epochs[i]`` is shown in the title of frame i.
        examples: sequence of 2-D array-likes, one per frame, drawn with ``imshow``.
        step_name: prefix for the frame title.
        filename: output path passed to ``Animation.save``.
        writer: matplotlib animation writer name (e.g. 'imagemagick' or 'pillow').
        fps: frames per second of the saved animation.

    Raises:
        ValueError: if ``examples`` is empty.
    """
    import numpy as np

    if len(examples) == 0:
        raise ValueError('examples must contain at least one frame')

    # imshow fixes its colour scale from the first frame and set_array() does not
    # rescale, so compute one shared scale over all frames up front.
    vmin = min(float(np.min(frame)) for frame in examples)
    vmax = max(float(np.max(frame)) for frame in examples)

    fig = plt.figure()
    try:
        im = plt.imshow(examples[0], interpolation='none', vmin=vmin, vmax=vmax)

        def init():
            fig.suptitle(step_name + ': {}'.format(epochs[0]))
            im.set_data(examples[0])
            return [im]

        def animate(i):
            fig.suptitle(step_name + ': {}'.format(epochs[i]))
            im.set_array(examples[i])
            return [im]

        ani = FuncAnimation(fig, animate, init_func=init,
                            frames=len(examples), interval=300, repeat=True)
        ani.save(filename, writer=writer, fps=fps)
    finally:
        # Close (not just clear) the figure so it is released even if saving fails.
        plt.close(fig)

In [None]:
# Training configuration.
# NOTE(review): this cell instantiates `DiffusionModel` (which in turn builds `Model`);
# both classes are defined in cells further down the notebook, so run those cells first
# (or move them above this one) for Restart & Run All to work.
SEED = 0  # NOTE(review): added for reproducibility; any fixed value works
T = 2000
num_sample_steps = 30
epochs = 100
lr = 0.5
save_every_n_epochs = 10  # NOTE(review): not used anywhere in this cell
num_examples = 10
num_val_samples = 1024
val_batch_size = 64
clip_thresh = 1.0
data_batch_size = 64

torch.manual_seed(SEED)

# Cursor into `index_permutation`; advanced (and reset at epoch end) by get_next().
cur_ind = 0

# Training data: one sequence per line, each character a digit parsed as a float.
data = []
with open('train.txt', 'r') as f:
    for line in f:
        digits = line.strip()
        if not digits:
            continue  # a blank line would yield an empty row and break torch.vstack
        data.append(torch.FloatTensor([float(x) for x in digits]).to("cpu"))

data = torch.vstack(data)
num_data, sequence_length = data.size()
index_permutation = torch.randperm(num_data)

diffusion_model = DiffusionModel(sequence_length, num_sample_steps, T)
optimizer = optim.SGD(diffusion_model.parameters(), lr=lr)

losses = []
examples_per_epoch = []
proportions = {'train': [], 'val': [], 'other': []}
# Fixed starting noise, reused every epoch so the example samples are comparable.
random_seed = Binomial(1, torch.full((num_examples, sequence_length), 0.5)).sample().to("cpu")

for epoch in range(epochs):
    loss_sum = 0.0
    batch_count = 0
    batch = get_next()
    while batch is not None:
        optimizer.zero_grad()
        output = diffusion_model(batch)
        # Catch inf as well as NaN: either means the optimisation has diverged.
        if not torch.isfinite(output):
            raise RuntimeError('learning rate is too high')
        output.backward()
        torch.nn.utils.clip_grad_norm_(diffusion_model.parameters(), clip_thresh)
        optimizer.step()
        loss_sum += output.item()
        batch_count += 1
        batch = get_next()
    if batch_count == 0:
        raise RuntimeError('get_next() returned no batches in epoch {}'.format(epoch))
    losses.append(loss_sum / batch_count)

    sample_data = diffusion_model.p_sample(num_examples, random_seed)
    examples_per_epoch.append(sample_data)

    train_count, val_count, other_count = validate(diffusion_model, num_val_samples, val_batch_size)

    proportions['train'].append(train_count / num_val_samples)
    proportions['val'].append(val_count / num_val_samples)
    proportions['other'].append(other_count / num_val_samples)

    print('Epoch: {}: Loss: {} Train Proportion: {} Val Proportion: {}'.format(
        epoch, losses[-1], proportions['train'][-1], proportions['val'][-1]))

results = {'losses': losses,
           'examples_per_epoch': [x.tolist() for x in examples_per_epoch],
           'proportions': proportions}
with open('results.json', 'w') as fp:
    json.dump(results, fp)

epochs_plotted = list(range(epochs))

# Explicit figure handles + plt.close() so no empty figures linger in the cell output.
fig, ax = plt.subplots()
ax.plot(epochs_plotted, losses)
ax.set_title('Training Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
fig.savefig(f"loss_curve_{lr}_{epochs}_SGD.png")
plt.close(fig)

numpy_examples = [x.cpu().detach().numpy() for x in examples_per_epoch]
plot_evolution(epochs_plotted,
               numpy_examples,
               step_name='Epoch',
               filename=f'sample_evolution_throughout_training_{lr}_{epochs}_SGD.gif')

fig, ax = plt.subplots()
for key, item in proportions.items():
    ax.plot(epochs_plotted, item, label=key)
ax.set_title('Validation Proportions')
ax.set_xlabel('Epoch')
ax.set_ylabel('Proportion')
ax.legend()
fig.savefig(f'validation_{lr}_{epochs}_SGD.png')
plt.close(fig)

Epoch: 0: Loss: 126772.9765625 Train Proportion: 0.0 Val Proportion: 0.0
Epoch: 1: Loss: 126681.5859375 Train Proportion: 0.0 Val Proportion: 0.0
Epoch: 2: Loss: 126618.96875 Train Proportion: 0.0 Val Proportion: 0.0
Epoch: 3: Loss: 126550.9921875 Train Proportion: 0.0 Val Proportion: 0.0
Epoch: 4: Loss: 126484.3984375 Train Proportion: 0.0 Val Proportion: 0.0
Epoch: 5: Loss: 126420.9296875 Train Proportion: 0.0 Val Proportion: 0.0
Epoch: 6: Loss: 126351.1328125 Train Proportion: 0.0 Val Proportion: 0.0
Epoch: 7: Loss: 126283.8828125 Train Proportion: 0.0 Val Proportion: 0.0
Epoch: 8: Loss: 126210.8359375 Train Proportion: 0.0 Val Proportion: 0.0
Epoch: 9: Loss: 126144.0859375 Train Proportion: 0.0 Val Proportion: 0.0
Epoch: 10: Loss: 126065.7734375 Train Proportion: 0.0 Val Proportion: 0.0
Epoch: 11: Loss: 125996.4765625 Train Proportion: 0.0 Val Proportion: 0.0
Epoch: 12: Loss: 125923.625 Train Proportion: 0.0 Val Proportion: 0.0
Epoch: 13: Loss: 125840.7421875 Train Proportion: 0.0 



<Figure size 640x480 with 0 Axes>

<Figure size 640x480 with 0 Axes>

In [None]:
import torch
from torch import nn
from torch.distributions import Binomial
from math import comb, pow, log2


class DiffusionModel(nn.Module):
    """Binary (Bernoulli) diffusion model over fixed-length 0/1 sequences.

    Forward (noising) process: at step t each bit is resampled from a fair coin with
    probability ``beta_t(t) = 1 / (T - t + 1)``. The reverse process is the learned
    per-timestep Bernoulli network ``Model``. ``forward`` returns the minibatch mean
    of the summed per-step KL divergences plus the constant entropy terms
    ``H_start - H_end + H_prior`` (all in bits).
    """

    def __init__(self, sequence_length, num_sample_steps, T):
        super().__init__()

        self.model = Model(sequence_length, T).to("cpu")

        self.sequence_length = sequence_length
        self.num_sample_steps = num_sample_steps  # NOTE(review): stored but not used by any method here

        self.final_noise = None  # x_T drawn during the most recent forward()
        self.T = T
        self.beta_tilde_t = self._compute_beta_tilde_schedule()

        self.H_start = self.entropy_of_q_conditional(self.sequence_length, self.beta_tilde_t[1, 0].item())
        self.H_end = self.entropy_of_q_conditional(self.sequence_length, self.beta_tilde_t[self.T, 0].item())
        self.H_prior = self.entropy_of_prior(self.sequence_length)

    def _compute_beta_tilde_schedule(self):
        """Return the (T + 1, sequence_length) tensor of cumulative resampling rates.

        Row 0 is zero and row t is
        beta_tilde_t = beta_tilde_{t-1} + beta_t - beta_t * beta_tilde_{t-1},
        i.e. 1 - prod_{s<=t} (1 - beta_s).
        """
        schedule = [torch.zeros((self.sequence_length)).to("cpu")]
        for cur_t in range(1, self.T + 1):
            prev = schedule[cur_t - 1]
            b = self.beta_t(cur_t)
            schedule.append(prev + b - b * prev)
        return torch.stack(schedule)

    @staticmethod
    def _binary_entropy(p):
        """Entropy in bits of one Bernoulli(p) variable, with 0 * log 0 taken as 0."""
        if p <= 0.0 or p >= 1.0:
            return 0.0
        return -(p * log2(p) + (1.0 - p) * log2(1.0 - p))

    @torch.no_grad()
    def entropy_of_q_conditional(self, sequence_length, beta_tilde_t):
        """Entropy (bits) of q(x_t | x_0) for ``sequence_length`` independent bits.

        Each bit differs from x_0 with probability 0.5 * beta_tilde_t, so the joint
        entropy is sequence_length * H_b(0.5 * beta_tilde_t). This closed form replaces
        an explicit binomial sum, which overflowed (``comb(n, k)`` too large for a
        float) for long sequences and was biased by its ``log2(prob + eps)`` guard once
        per-sequence probabilities fell below eps.
        """
        return sequence_length * self._binary_entropy(0.5 * beta_tilde_t)

    def entropy_of_prior(self, sequence_length):
        '''Entropy (bits) of the prior, assuming every Bernoulli in it has prob 0.5.

        This is exactly float(sequence_length).'''
        return sequence_length * self._binary_entropy(0.5)

    def p_conditional_prob(self, x_t, t):
        """Per-bit probabilities p(x_{t-1} = 1 | x_t) from the learned network."""
        return self.model(x_t, t)

    @torch.no_grad()
    def p_step(self, x, t):
        """Sample x_{t-1} ~ p(. | x_t = x) once."""
        return torch.bernoulli(self.model(x, t))

    @torch.no_grad()
    def p_sample(self, batch_size, x=None):
        """Run the reverse chain from t = T down to t = 1.

        Starts from ``x`` if given (its first dimension must equal ``batch_size``),
        otherwise from fair-coin noise of shape (batch_size, sequence_length).
        """
        if x is None:
            init_prob = torch.empty((batch_size, self.sequence_length)).fill_(0.5).to("cpu")
            x = torch.bernoulli(init_prob)
        else:
            assert batch_size == x.size(dim=0)

        for cur_t in range(self.T, 0, -1):
            x = torch.bernoulli(self.p_conditional_prob(x, cur_t))
        return x

    def beta_t(self, t):
        """Per-step resampling rate 1 / (T - t + 1); equals 1 at t = T."""
        return 1.0 / (self.T - t + 1)

    def kl_div(self, q, p):
        '''KL(q || p) in bits between factorised multivariate Bernoulli distributions.

        ``q`` and ``p`` hold per-bit success probabilities; the per-bit terms are summed
        over dim 1.'''
        eps = 1e-30
        # Clip away from exact 0/1 so the logarithms stay finite.
        q = torch.clip(q, min=1e-10, max=1-(1e-7))
        p = torch.clip(p, min=1e-10, max=1-(1e-7))
        return torch.sum((q * torch.log2((q/p) + eps)) + ((1.0-q) * torch.log2(((1.0-q)/(1.0-p)) + eps)), dim=1)

    def q_conditional_prob(self, x_t, t):
        """Per-bit probabilities of the next forward step, using rate beta_t(t + 1).

        Note: beta_t(T + 1) divides by zero, so this is only valid for t < T.
        """
        # had to change the beta_t equation here to keep indexing consistent
        return (x_t * (1.0 - self.beta_t(t+1))) + 0.5 * self.beta_t(t+1)

    def q_conditional_prob_wrt_x_0(self, x_0, t):
        """Per-bit probabilities q(x_t = 1 | x_0) using the cumulative rate beta_tilde_t."""
        beta_tilde_t = self.beta_tilde_t[t].expand(x_0.size())
        return ((x_0 * (1.0 - beta_tilde_t)) + 0.5 * beta_tilde_t)

    @torch.no_grad()
    def q_step(self, x, t):
        """Sample one forward step from x (see q_conditional_prob)."""
        probs = self.q_conditional_prob(x, t)
        return torch.bernoulli(probs)

    def q_sample(self, x_0, t):
        """Sample x_t ~ q(. | x_0) directly."""
        return torch.bernoulli(self.q_conditional_prob_wrt_x_0(x_0, t))

    def forward(self, x_0):
        """Return the minibatch mean of the loss (bits) for the batch ``x_0``.

        For each t the loss compares the analytic posterior q(x_{t-1} | x_t, x_0)
        against the network's p(x_{t-1} | x_t); the sampled x_t for each example
        provides the Monte Carlo estimate.
        """
        total_loss = torch.zeros((x_0.size(dim=0),)).to("cpu")
        for t in range(1, self.T + 1):
            q_t_given_0 = self.q_conditional_prob_wrt_x_0(x_0, t)
            x_t = torch.bernoulli(q_t_given_0)
            beta_t = self.beta_t(t)
            # q(x_{t-1} = 1 | x_0)
            posterior = x_0 * (1 - self.beta_tilde_t[t-1]) + 0.5 * self.beta_tilde_t[t-1]
            # times q(x_t | x_{t-1} = 1), evaluated at the sampled x_t
            posterior = posterior * (x_t * (1 - 0.5 * beta_t) + (1 - x_t) * (0.5 * beta_t))
            # divided by q(x_t | x_0), evaluated at the sampled x_t (Bayes' rule)
            normalizing_constant = x_t * q_t_given_0 + (1 - x_t) * (1 - q_t_given_0)
            posterior = posterior / normalizing_constant

            total_loss += self.kl_div(posterior, self.p_conditional_prob(x_t, t))

            if t == self.T:
                self.final_noise = x_t

        # The entropy terms do not depend on t, so they enter the bound once, not once per step.
        total_loss += self.H_start - self.H_end + self.H_prior
        return torch.mean(total_loss)


In [None]:
import torch
from torch import nn


class Model(nn.Module):
    """Reverse-process network mapping (x, t) to per-bit Bernoulli probabilities.

    A shared two-layer sigmoid MLP trunk feeds one linear output head per timestep
    (``T`` heads; ``t`` in 1..T selects head ``t - 1``), followed by a sigmoid.
    """

    def __init__(self, sequence_length, T, hidden_size=70):
        super().__init__()
        self.shared_layers = torch.nn.Sequential(
            torch.nn.Linear(sequence_length, hidden_size),
            torch.nn.Sigmoid(),
            torch.nn.Linear(hidden_size, hidden_size),
            torch.nn.Sigmoid()
        )

        self.output_layers = nn.ModuleList(
            [torch.nn.Linear(hidden_size, sequence_length) for _ in range(T)])
        self.outputs_sigmoid = torch.nn.Sigmoid()

    def forward(self, x, t):
        """Return sigmoid probabilities with the same trailing size as ``x``.

        Raises:
            IndexError: if ``t`` is outside 1..T. Without this check, t <= 0 would
                silently pick a head from the end of the list via negative indexing.
        """
        if not 1 <= t <= len(self.output_layers):
            raise IndexError('t must be in 1..{}, got {}'.format(len(self.output_layers), t))
        x = self.shared_layers(x)
        x = self.output_layers[t - 1](x)
        x = self.outputs_sigmoid(x)
        return x