
# **MNIST - Denoising Diffusion Probabilistic Model**

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Deyht/AI_astro_ED_AAIF/blob/main/codes/CNN/diffusion/MNIST_diffusion.ipynb)

---


### **CIANNA BETA DEV installation**

/!\ WARNING /!\
This beta version is not suited for general use and has been modified for the specific case covered in this notebook. Some functions might behave differently than expected.
Do not use outside this notebook !

#### Query GPU allocation and properties

If nvidia-smi fails, it might indicate that you launched the Colab session without a GPU reservation.  
To change the type of reservation go to "Runtime"->"Change runtime type" and select "GPU" as your hardware accelerator.

In [None]:
%%shell

# Show the GPU attached to this session (fails when no GPU was reserved).
nvidia-smi

cd /content/

# Fetch NVIDIA's CUDA samples in order to build the deviceQuery utility.
git clone https://github.com/NVIDIA/cuda-samples/

cd /content/cuda-samples/Samples/1_Utilities/deviceQuery/

cmake CMakeLists.txt

make SMS="50 60 70 80"

# Keep only the value part (fixed column offset) of two deviceQuery lines:
# line 1 of ~/cuda_infos.txt comes from the "Capability" line, line 2 from the
# "CUDA Driver Version / Runtime Version" line. The compile cell reads them
# back as the compute capability and the CUDA version.
./deviceQuery | grep Capability | cut -c50- > ~/cuda_infos.txt
./deviceQuery | grep "CUDA Driver Version / Runtime Version" | cut -c57- >> ~/cuda_infos.txt

cd ~/

If you are granted a GPU that supports high FP16 compute scaling (e.g. the Tesla T4), it is advised to set the mixed_precision parameter in the training cell to "FP16C_FP32A".  
See the detailed description of mixed precision support in CIANNA on the [System Requirements](https://github.com/Deyht/CIANNA/wiki/1\)-System-Requirements) wiki page.

#### Download the CIANNA source archive

In [None]:
%%shell

cd /content/

# Download the CIANNA experimental archive, unpack it, and rename the
# extracted directory to /content/CIANNA, the path used by the later cells.
wget https://share.obspm.fr/s/KNw8aYAEjfxJzsR/download/CIANNA_exp_07_02_25.tar.gz
tar -xvzf CIANNA_exp_07_02_25.tar.gz
mv CIANNA_exp_07_02_25 CIANNA

cd CIANNA

#### Compiling CIANNA for the allocated GPU generation

There is no guaranteed forward or backward compatibility between Nvidia GPU generations, and some capabilities are generation-specific. For these reasons, CIANNA must be told the GPU generation of the platform at compile time.
The following cell will automatically update all the necessary files based on the detected GPU, and compile CIANNA.

In [None]:
%%shell

cd /content/CIANNA

# Factor turning a compute capability such as "7.5" into the integer used in -arch=sm_XX.
mult="10"
# Print the values saved by the GPU query cell, then read them back:
# line 1 is the compute capability, line 2 the CUDA version.
cat ~/cuda_infos.txt
comp_cap="$(sed '1!d' ~/cuda_infos.txt)"
cuda_vers="$(sed '2!d' ~/cuda_infos.txt)"

# Add -D CUDA_OLD when the detected CUDA version is lower than lim.
lim="11.1"
old_arg=$(awk '{if ($1 < $2) print "-D CUDA_OLD";}' <<<"${cuda_vers} ${lim}")

sm_val=$(awk '{print $1*$2}' <<<"${mult} ${comp_cap}")

# Generation flag: sm >= 80 gives GEN_AMPERE, sm >= 70 gives GEN_VOLTA, otherwise no flag.
gen_val=$(awk '{if ($1 >= 80) print "-D GEN_AMPERE"; else if($1 >= 70) print "-D GEN_VOLTA";}' <<<"${sm_val}")

# Rewrite the line of compile.cp that contains "arch=sm" with the detected flags,
# then point the cuda-XX.X paths of both build files at the detected CUDA version.
sed -i "s/.*arch=sm.*/\\t\tcuda_arg=\"\$cuda_arg -D CUDA -D comp_CUDA -lcublas -lcudart -arch=sm_$sm_val $old_arg $gen_val\"/g" compile.cp
sed -i "s/\/cuda-[0-9][0-9].[0-9]/\/cuda-$cuda_vers/g" compile.cp
sed -i "s/\/cuda-[0-9][0-9].[0-9]/\/cuda-$cuda_vers/g" src/python_module_setup.py

# Build CIANNA with CUDA support and its Python interface.
./compile.cp CUDA PY_INTERF

# Give the version-suffixed build directory a fixed name.
mv src/build/lib.linux-x86_64-* src/build/lib.linux-x86_64

#### Testing CIANNA installation

**IMPORTANT NOTE**   
CIANNA is mainly used in a scripted fashion and was not designed to run in notebooks. Every code cell that directly invokes CIANNA functions must be run as a script to avoid possible errors.  
To do so, the cell must have the following structure.

```
%%shell

cd /content/CIANNA

python3 - <<EOF

[... your python code ...]

EOF
```

This syntax allows one to easily edit Python code in the notebook while running the cell as a script. Note that notebook variables cannot be accessed from the cell in this context.


### **MNIST image generator**


#### Data Handling

In [None]:
%cd /content/

import numpy as np
import glob, os
import matplotlib.pyplot as plt

def cosin_schedule(t, T, s):
	"""Return the squared-cosine schedule value at step t out of T, with offset s.

	Computes cos(((t/T + s)/(1 + s)) * pi/2)**2; t may be a scalar or a
	NumPy array, in which case the result is element-wise.
	"""
	phase = (t/T + s)/(1.0 + s)
	return np.cos(phase*(np.pi/2))**2

# Side length, in pixels, of the square input images
image_size = 28

# Download and unpack the dataset archive only when it is not on disk yet
if not os.path.isdir("mnist_dat"):
	os.system("wget https://share.obspm.fr/s/EkYR5B2Wc2gNis3/download/mnist.tar.gz")
	os.system("tar -xvzf mnist.tar.gz")


print ("Reading inputs ... ", end = "", flush=True)

# Raw float32 binary file, viewed as one flattened 28*28 image per row
raw_data = np.fromfile("mnist_dat/mnist_input.dat", dtype="float32").reshape((-1,28*28))

nb_raw_images = raw_data.shape[0]

# Number of points in the noise schedule
n_step = 400

# Bounds of the signal rate after rescaling
min_signal_rate = 0.05**2
max_signal_rate = 1.00**2

# Cosine schedule sampled at n_step points, then rescaled so that every entry
# but the last one spans [min_signal_rate, max_signal_rate]
p_alpha_t = cosin_schedule(np.linspace(0,n_step, n_step),n_step,0.0008)
p_alpha_t = p_alpha_t - np.min(p_alpha_t[:-1])
p_alpha_t = p_alpha_t / np.max(p_alpha_t[:-1])

p_alpha_t = min_signal_rate + p_alpha_t * (max_signal_rate - min_signal_rate)


def create_batch(nb_im_size):
	"""Build nb_im_size noisy images and the noise that was added to each.

	Each sample picks a random raw image and a random schedule index in
	[1, n_step-2], then mixes the image with unit Gaussian noise using the
	signal rate p_alpha_t at that index.

	Returns (data, targets), two float32 arrays of shape
	(nb_im_size, image_size*image_size): the noisy images and the noise.
	"""
	n_pix = image_size*image_size
	data = np.zeros((nb_im_size, n_pix), dtype="float32")
	targets = np.zeros((nb_im_size, n_pix), dtype="float32")

	for row in range(nb_im_size):
		# Random draws are made in the order image, step, noise
		src = int(np.random.random()*nb_raw_images)
		level = np.random.randint(0,n_step-2) + 1
		eps = np.random.normal(loc=0.0, scale=1.0, size=n_pix)

		signal_rate = p_alpha_t[level]
		data[row,:] = np.sqrt(signal_rate)*raw_data[src,:] + np.sqrt(1.0-signal_rate)*eps
		targets[row,:] = eps

	return data, targets



#### Visualize random subset of raw data

In [None]:
# Display a sq_size x sq_size grid of randomly selected raw images
sq_size = 5

# One random image index per grid cell
im_index = np.random.randint(0,nb_raw_images, sq_size**2)

fig, axs = plt.subplots(sq_size, sq_size, figsize=(1.2*sq_size,1.2*sq_size), dpi=120, constrained_layout=True)

for i in range(0, sq_size):
  for j in range(0, sq_size):
    axs[i][j].set_axis_off()
    # Index through im_index so that each cell shows a randomly drawn image
    patch = np.reshape(raw_data[im_index[i*sq_size+j]],(image_size, image_size))
    axs[i][j].imshow(patch, interpolation="bilinear", cmap="binary")
plt.show()

#### Visualize a training noise chain




In [None]:
# Ten evenly spaced schedule steps, shown left to right with increasing noise
steps_list = np.arange(0,n_step, n_step//10)

print (steps_list)

# The same raw image is used for every panel
im_id = 0
patch = raw_data[im_id]

fig, axs = plt.subplots(1, len(steps_list), figsize=(2*len(steps_list),2), dpi=150, constrained_layout=True)

for i, step in enumerate(steps_list):
  # Fresh unit Gaussian noise for each panel
  noise_patch = np.random.normal(loc=0.0, scale=1.0, size=(image_size*image_size))

  # Same mixing rule as in create_batch, at schedule index step+1
  patch_in = np.sqrt(p_alpha_t[step+1])*patch + np.sqrt(1.0-p_alpha_t[step+1])*noise_patch

  # Clip to [0, 1] for display only
  shown = np.clip(patch_in.reshape((image_size,image_size)),0,1)
  axs[i].set_axis_off()
  axs[i].imshow(shown, interpolation="bilinear", cmap="binary")

plt.show()

#### Visualize a training batch

In [None]:
# Display a sq_size x sq_size grid of noisy training inputs
sq_size = 5

# One noisy training sample per grid cell (the noise targets are not displayed)
data_augm, target_augm = create_batch(sq_size**2)

fig, axs = plt.subplots(sq_size, sq_size, figsize=(1.2*sq_size,1.2*sq_size), dpi=120, constrained_layout=True)

for i in range(0, sq_size):
  for j in range(0, sq_size):
    axs[i][j].set_axis_off()
    # Clip to [0, 1] for display only
    patch = np.clip(np.reshape(data_augm[i*sq_size+j],(image_size,image_size)),0.0,1.0)
    axs[i][j].imshow(patch, interpolation="bilinear", cmap="binary")
plt.show()

#### Train the diffusion model



In [None]:
%%shell

cd /content/

# Run the training code as a standalone script fed to python3 on stdin.
# NOTE: the heredoc delimiter is unquoted, so the shell expands dollar signs,
# backticks and backslash-newlines in the Python code before python3 reads it.
python3 - <<EOF

import time
import locale
import matplotlib.pyplot as plt
from scipy import signal
from threading import Thread

import numpy as np

import sys, glob, os
sys.path.insert(0,glob.glob('/content/CIANNA/src/build/lib.*/')[-1])
import CIANNA_exp as cnn

def i_ar(int_list):
	"""Return a new integer NumPy array built from int_list."""
	return np.array(int_list, dtype=int)

def f_ar(float_list):
	"""Return a new float32 NumPy array built from float_list."""
	return np.array(float_list, dtype=np.float32)

def cosin_schedule(t, T, s):
	"""Return the squared-cosine schedule value at step t out of T, with offset s.

	Computes cos(((t/T + s)/(1 + s)) * pi/2)**2; t may be a scalar or a
	NumPy array, in which case the result is element-wise.
	"""
	phase = (t/T + s)/(1.0 + s)
	return np.cos(phase*(np.pi/2))**2


# Side length, in pixels, of the square input images
image_size = 28

# Download and unpack the dataset archive only when it is not on disk yet
if not os.path.isdir("mnist_dat"):
	os.system("wget https://share.obspm.fr/s/EkYR5B2Wc2gNis3/download/mnist.tar.gz")
	os.system("tar -xvzf mnist.tar.gz")

# Output directory for the figures saved during training
if not os.path.isdir("fig"):
	os.system("mkdir fig")

print ("Reading inputs ... ", end = "", flush=True)

# Raw float32 binary file, viewed as one flattened 28*28 image per row
raw_data = np.fromfile("mnist_dat/mnist_input.dat", dtype="float32").reshape((-1,28*28))

nb_raw_images = raw_data.shape[0]

# Number of points in the noise schedule
n_step = 400

# Bounds of the signal rate after rescaling
min_signal_rate = 0.05**2
max_signal_rate = 1.00**2

# Cosine schedule sampled at n_step points, then rescaled so that every entry
# but the last one spans [min_signal_rate, max_signal_rate]
p_alpha_t = cosin_schedule(np.linspace(0,n_step, n_step),n_step,0.0008)
p_alpha_t = p_alpha_t - np.min(p_alpha_t[:-1])
p_alpha_t = p_alpha_t / np.max(p_alpha_t[:-1])

p_alpha_t = min_signal_rate + p_alpha_t * (max_signal_rate - min_signal_rate)


def create_batch(nb_im_size):
	"""Build nb_im_size noisy images and the noise that was added to each.

	Each sample picks a random raw image and a random schedule index in
	[1, n_step-2], then mixes the image with unit Gaussian noise using the
	signal rate p_alpha_t at that index.

	Returns (data, targets), two float32 arrays of shape
	(nb_im_size, image_size*image_size): the noisy images and the noise.
	"""
	n_pix = image_size*image_size
	data = np.zeros((nb_im_size, n_pix), dtype="float32")
	targets = np.zeros((nb_im_size, n_pix), dtype="float32")

	for row in range(nb_im_size):
		# Random draws are made in the order image, step, noise
		src = int(np.random.random()*nb_raw_images)
		level = np.random.randint(0,n_step-2) + 1
		eps = np.random.normal(loc=0.0, scale=1.0, size=n_pix)

		signal_rate = p_alpha_t[level]
		data[row,:] = np.sqrt(signal_rate)*raw_data[src,:] + np.sqrt(1.0-signal_rate)*eps
		targets[row,:] = eps

	return data, targets


def data_augm():
	"""Rebuild the "TRAIN_buf" dataset from a fresh random batch of 2048 noisy images.

	Used as a thread target by the training loop; returns None.
	"""
	fresh_inputs, fresh_targets = create_batch(2048)
	# Drop the previous buffer before registering the new one under the same name
	cnn.delete_dataset("TRAIN_buf", silent=1)
	cnn.create_dataset("TRAIN_buf", 2048, fresh_inputs, fresh_targets, silent=1)


# Number of samples in the training set, the validation set, and the sampling preview
nb_train = 2048
nb_valid = 256
nb_test = 32

# Initial noisy-image / noise pairs for the "TRAIN" and "VALID" datasets
data_train, target_train = create_batch(nb_train)
data_valid, target_valid = create_batch(nb_valid)

# Create denoiser network: single-channel image_size x image_size input,
# one output value per pixel (the predicted noise)
cnn.init(in_dim=i_ar([image_size,image_size]), in_nb_ch=1, out_dim=image_size*image_size, \
		bias=0.1, b_size=16, use_wema=1, comp_meth="C_CUDA", dynamic_load=1, mixed_precision="FP16C_FP32A") #Change comp_meth to C_BLAS or C_NAIV (presumably the CPU back-ends; confirm in the CIANNA docs)

cnn.create_dataset("TRAIN", size=nb_train, input=data_train, target=target_train)
cnn.create_dataset("VALID", size=nb_valid, input=data_valid, target=target_valid)


# Base number of convolution filters; deeper stages use multiples of it
f_fact = 32

# Number of pixels in one flattened image
im_s = image_size*image_size

# Set load_epoch > 0 to reload a saved network instead of building a new one
load_epoch = 0
if(load_epoch > 0):
	cnn.load("net_save/net0_s%04d.dat"%load_epoch,load_epoch, bin=1)
else:
	# Encoder stage 1: two 3x3 convolutions (l_1 is kept for a later merge), then a stride-2 convolution
	cnn.conv(f_size=i_ar([3,3]), nb_filters=f_fact   , stride=i_ar([1,1]), padding=i_ar([1,1]), activation="RELU")
	l_1 = cnn.conv(f_size=i_ar([3,3]), nb_filters=f_fact   , stride=i_ar([1,1]), padding=i_ar([1,1]), activation="RELU")
	cnn.conv(f_size=i_ar([2,2]), nb_filters=f_fact   , stride=i_ar([2,2]), padding=i_ar([0,0]), activation="RELU")

	# Encoder stage 2: same pattern with twice the filters (l_2 is kept for a later merge)
	cnn.conv(f_size=i_ar([3,3]), nb_filters=f_fact*2 , stride=i_ar([1,1]), padding=i_ar([1,1]), activation="RELU")
	l_2 = cnn.conv(f_size=i_ar([3,3]), nb_filters=f_fact*2 , stride=i_ar([1,1]), padding=i_ar([1,1]), activation="RELU")
	cnn.conv(f_size=i_ar([2,2]), nb_filters=f_fact*2 , stride=i_ar([2,2]), padding=i_ar([0,0]), activation="RELU")

	# Encoder stage 3: four times the filters (l_3 is kept for a later merge); the stride-2 layer is 3x3 here
	cnn.conv(f_size=i_ar([3,3]), nb_filters=f_fact*4 , stride=i_ar([1,1]), padding=i_ar([1,1]), activation="RELU")
	l_3 = cnn.conv(f_size=i_ar([3,3]), nb_filters=f_fact*4 , stride=i_ar([1,1]), padding=i_ar([1,1]), activation="RELU")
	cnn.conv(f_size=i_ar([3,3]), nb_filters=f_fact*4 , stride=i_ar([2,2]), padding=i_ar([0,0]), activation="RELU")

	# Deepest stage: eight times the filters.
	# NOTE(review): int_padding presumably inserts padding between input pixels so that
	# the convolution upsamples back to the size of l_3 - confirm in the CIANNA docs
	cnn.conv(f_size=i_ar([3,3]), nb_filters=f_fact*8 , stride=i_ar([1,1]), padding=i_ar([1,1]), activation="RELU")
	cnn.conv(f_size=i_ar([3,3]), nb_filters=f_fact*8 , stride=i_ar([1,1]), padding=i_ar([1,1]), activation="RELU")
	cnn.conv(f_size=i_ar([3,3]), nb_filters=f_fact*8 , stride=i_ar([1,1]), padding=i_ar([1,1]), int_padding=i_ar([2,2]), activation="RELU")

	# Decoder stage 3: merge with the saved l_3 layer (presumably -1 refers to the previous layer)
	cnn.merge(-1, l_3, "CONCATENATE")
	cnn.conv(f_size=i_ar([3,3]), nb_filters=f_fact*4 , stride=i_ar([1,1]), padding=i_ar([1,1]), activation="RELU")
	cnn.conv(f_size=i_ar([3,3]), nb_filters=f_fact*4 , stride=i_ar([1,1]), padding=i_ar([1,1]), activation="RELU")
	cnn.conv(f_size=i_ar([2,2]), nb_filters=f_fact*4 , stride=i_ar([1,1]), padding=i_ar([1,1]), int_padding=i_ar([1,1]), activation="RELU")

	# Decoder stage 2: merge with the saved l_2 layer
	cnn.merge(-1, l_2, "CONCATENATE")
	cnn.conv(f_size=i_ar([3,3]), nb_filters=f_fact*2 , stride=i_ar([1,1]), padding=i_ar([1,1]), activation="RELU")
	cnn.conv(f_size=i_ar([3,3]), nb_filters=f_fact*2 , stride=i_ar([1,1]), padding=i_ar([1,1]), activation="RELU")
	cnn.conv(f_size=i_ar([2,2]), nb_filters=f_fact*2 , stride=i_ar([1,1]), padding=i_ar([1,1]), int_padding=i_ar([1,1]), activation="RELU")

	# Decoder stage 1: merge with the saved l_1 layer, then a linear 1x1 convolution with a single filter
	cnn.merge(-1, l_1, "CONCATENATE")
	cnn.conv(f_size=i_ar([3,3]), nb_filters=f_fact   , stride=i_ar([1,1]), padding=i_ar([1,1]), activation="RELU")
	cnn.conv(f_size=i_ar([3,3]), nb_filters=f_fact   , stride=i_ar([1,1]), padding=i_ar([1,1]), activation="RELU")
	cnn.conv(f_size=i_ar([1,1]), nb_filters=1        , stride=i_ar([1,1]), padding=i_ar([0,0]), activation="LIN" )


# For testing on static input noise: input_noise is drawn once and reused at every preview
input_images = np.zeros((nb_test,image_size*image_size), dtype="float32")
# All-zero targets passed when creating the "TEST" dataset
targets_diff = np.zeros((nb_test,image_size*image_size), dtype="float32")
input_noise = np.random.normal(loc=0.0, scale=1.0, size=(nb_test,image_size*image_size))

# Reusable 2D buffer and interpolation mode for the preview figures
patch = np.zeros((image_size, image_size))
interp = "bilinear"


for i in range(load_epoch,2000):
	t = Thread(target=data_augm)
	t.start()

	cnn.train(nb_iter=1, learning_rate=0.0006, end_learning_rate=0.00005, lr_decay=0.001, momentum=0.0, weight_decay=0.00002, wema_rate=0.995,
		confmat=0, save_every=50, save_bin=1, TC_scale_factor=4.0, shuffle_every=10000)
	t.join()

	if(i == 0):
		cnn.perf_eval()

	cnn.swap_data_buffers("TRAIN")

	if((i+1)%50 == 0):
		input_images[:,0:im_s] = input_noise[:,:]
		input_images[:,im_s:] = np.sqrt(1.0-p_alpha_t[-2])

		cnn.create_dataset("TEST", nb_test, input_images, targets_diff)

		#For display of several diffusion step of the same input
		fig, axs = plt.subplots(4, 5, figsize=(10,8), dpi=200, constrained_layout=True)
		for j in range(0,20):
			axs[int(j/5)][int(j%5)].axis('off')

		fig.suptitle("Epoch %d"%(i+1))

		block_size = n_step // 20

		for step in range(1,n_step-3):
			cnn.forward(saving=2, no_error=1, silent=1)

			file_name = "fwd_res/net0_%04d.dat"%(i+1)
			fwd_dat = np.fromfile(file_name, dtype="float32")
			fwd_dat = np.reshape(fwd_dat,(nb_test,image_size*image_size))

			pred_noise = fwd_dat
			new_images = (input_images[:,:] - np.sqrt(1.0-p_alpha_t[n_step-(step+1)])*pred_noise)/(np.sqrt(p_alpha_t[n_step-(step+1)]))

			eta = np.sqrt((1-p_alpha_t[n_step-(step+2)])/(1-p_alpha_t[n_step-(step+1)]))*np.sqrt(1-p_alpha_t[n_step-(step+1)]/p_alpha_t[n_step-(step+2)])

			new_noise = np.random.normal(loc=0.0, scale=1.0, size=(nb_test,image_size*image_size))
			input_images[:,0:im_s] = np.sqrt(p_alpha_t[n_step-(step+2)])*new_images[:,:] + np.sqrt(1.0-p_alpha_t[n_step-(step+2)] - eta**2)*pred_noise + eta*new_noise[:,:]

			j = step/block_size

			if((step+1)%block_size == 0):
				patch[:,:] = np.clip(np.reshape(new_images[0][:],(image_size,image_size)),0.0,1.0)
				axs[int(j/5)][int(j%5)].imshow(patch[:,:], vmax=0.99, vmin=0.01, interpolation=interp, cmap="Greys")

			cnn.delete_dataset("TEST", silent=1)
			cnn.create_dataset("TEST", nb_test, input_images, targets_diff, silent=1)

		#Uncomment to save network prediction
		cnn.forward(saving=2, no_error=1, silent=1)
		cnn.delete_dataset("TEST", silent=1)

		fwd_dat = np.fromfile(file_name, dtype="float32")
		fwd_dat = np.reshape(fwd_dat,(nb_test,image_size*image_size))

		new_images = (input_images[:,0:im_s] - np.sqrt(1.0-p_alpha_t[0])*fwd_dat[:,:])/(np.sqrt(p_alpha_t[0]))

		patch[:,:] = np.clip(np.reshape(new_images[0][:],(image_size,image_size)),0.0,1.0)
		axs[3][4].imshow(patch[:,:], vmax=0.99, vmin=0.01, interpolation=interp, cmap="Greys")

		plt.savefig("fig/one_diff_fwd_%04d.png"%(i+1), dpi=200)
		plt.close()

		#For display of mulitple final generations for multiple inputs
		fig, axs = plt.subplots(4, 5, figsize=(10,8), dpi=200, constrained_layout=True)

		for j in range(0,20):
			axs[int(j/5)][int(j%5)].axis('off')

		fig.suptitle("Epoch %d"%(i+1))

		for j in range(0,20):
			patch[:,:] = np.clip(np.reshape(new_images[j][:],(image_size,image_size)),0.0,1.0)
			axs[int(j/5)][int(j%5)].imshow(patch[:,:], vmax=0.99, vmin=0.01, interpolation=interp, cmap="binary")

		plt.savefig("fig/all_diff_fwd_%04d.png"%(i+1), dpi=200)
		plt.close()


EOF


In [None]:
%cd /content/fig/
from IPython.display import Image

# Epoch of the saved figure to display; the training cell writes one every 50 epochs
epoch = 50

# Grid of final generations for multiple inputs at that epoch
Image("all_diff_fwd_%04d.png"%(epoch), width=720)


In [None]:
%cd /content/fig/
from IPython.display import Image

# Epoch of the saved figure to display; the training cell writes one every 50 epochs
epoch = 50

# Successive diffusion steps of a single input at that epoch
Image("one_diff_fwd_%04d.png"%(epoch), width=720)