# MNIST Digit Addition
This notebook shows how to generate images with SLASH: it trains a probabilistic circuit on MNIST-Addition and then samples class-conditional digit images from it.

In [None]:
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

# Show the pre-rendered overview figure comparing generated samples.
img = mpimg.imread('data/generative_overview.png')
fig, ax = plt.subplots()
ax.imshow(img)
ax.axis('off')
plt.show()

Qualitative comparison on generative MNIST-Addition: each row shows one image per class. The first row contains ground-truth images taken from the dataset; the second row shows images sampled from a PC trained purely generatively. The last two rows are sampled from the NPP trained within SLASH: the third row after one epoch of the discriminative learning step, and the last row after one epoch of the generative step.


### Generative training in SLASH

In [None]:
import train

# Experiment configuration: the knobs below are read by the training call and
# by the results cell, so edit them here rather than in exp_dict literals.
seed = 1
drop_out = 0.0

# The results cell loads checkpoints from data/<exp_name>/.
# NOTE(review): the name says 'epochs-100' but exp_dict trains for 30 epochs.
# The name is kept unchanged so existing result folders are still found;
# confirm which value is intended.
exp_name = f'pc-generative-poon-domingos-normal-seed-{seed}-epochs-100-pd-7'
exp_dict = {'structure': 'poon-domingos', 'pd_num_pieces': [7],
            'use_spn': True, 'credentials': 'DO', 'seed': seed, 'learn_prior': True,
            'lr': 0.01, 'bs': 100, 'epochs': 30, 'p_num': 8,
            # Use the variable above; previously a hard-coded 0.0 silently
            # ignored any change to drop_out.
            'drop_out': drop_out}
    

In [None]:
# Train the configured experiment; results are written under its folder name.
print(f"Experiment's folder is {exp_name}")
train.slash_mnist_addition(exp_name, exp_dict)

#### Show experiment results

In [None]:
# Rebuild the EiNet trained as the NPP, load the saved checkpoints, and plot
# class-conditional samples after the generative (EM) step and the SLASH step.
import sys

import torch

# Make the SLASH sources and the EinsumNetworks package importable.
sys.path.append('../../')
sys.path.append('../../SLASH/')
sys.path.append('../../EinsumNetworks/src/')
from einsum_wrapper import EiNet

EPOCH = 1               # epoch index of the checkpoints to load
NUM_CLASSES = 10        # class_count of the circuit; one grid column per class
SAMPLES_PER_CLASS = 10  # one grid row per sample
IMG_SIDE = 28           # samples are reshaped to IMG_SIDE x IMG_SIDE for display

# Use a probabilistic circuit: set up the network with the same structure
# parameters as the trained experiment so the saved state dict fits.
m = EiNet(structure=exp_dict['structure'],
          pd_num_pieces=exp_dict['pd_num_pieces'],
          use_em=False,
          num_var=IMG_SIDE * IMG_SIDE,
          class_count=NUM_CLASSES,
          #K = 40,
          #num_sums = 40,
          pd_width=IMG_SIDE,
          pd_height=IMG_SIDE,
          learn_prior=exp_dict['learn_prior'])


def show_class_samples(model, checkpoint_path, step,
                       n_classes=NUM_CLASSES, n_samples=SAMPLES_PER_CLASS):
    """Load a checkpoint into `model` and show a grid of class-conditional samples.

    Args:
        model: network exposing `load_state_dict` and
            `sample(num_samples=..., class_idx=...)`.
        checkpoint_path: file written by `torch.save` whose 'addition_net'
            entry holds the model's state dict.
        step: name of the training step the checkpoint was taken after
            (used in the printed message and the figure title).
        n_classes: number of classes to sample (grid columns).
        n_samples: samples drawn per class (grid rows).
    """
    print(f"loading model after {step} STEP")
    # map_location='cpu' lets GPU-saved checkpoints load on CPU-only machines;
    # load_state_dict then copies the tensors into the model's own parameters.
    saved_model = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(saved_model['addition_net'])

    # Size the figure at creation instead of toggling the global
    # 'figure.constrained_layout.use' rcParam. That rcParam only affects
    # figures created after it is set, and it turns subplots_adjust below
    # into a no-op, so the two grids used to render differently.
    fig, axs = plt.subplots(n_samples, n_classes, figsize=(10, 10), squeeze=False)
    for c in range(n_classes):
        samples = model.sample(num_samples=n_samples, class_idx=c).cpu().numpy()
        samples = samples.reshape((-1, IMG_SIDE, IMG_SIDE))
        for i, s in enumerate(samples):
            axs[i, c].axis('off')
            axs[i, c].imshow(s, cmap='gray')
    fig.suptitle(f"Class-conditional samples after the {step} step")
    fig.subplots_adjust(wspace=0, hspace=0)
    plt.show()


show_class_samples(
    m, f"data/{exp_name}/slash_digit_addition_models_generative_e{EPOCH}.pt", "EM")
show_class_samples(
    m, f"data/{exp_name}/slash_digit_addition_models_slash_e{EPOCH}.pt", "SLASH")

