# Import Dependencies

In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np
import torch.utils.data as Data 
from tifffile import imread
import struct

from noise import GaussianMixtureModel
from data import preprocess
from model_linear import DIVNOISING, Discriminator
from predict import predict
import train

# Configure Global Parameters

In [None]:
# Directories: the raw MNIST files live under data_path; loss plots and
# checkpoints are read from / written to model_path.
data_path, model_path = "../MNIST/", "./MNIST/"

# File names handed to train.train together with model_path.
loss_name = "loss.npz"
divnoising_model_name = "divnoising_model_last.net"

# Not referenced by any of the cells shown in this notebook section.
divnoising_data_parameters = "divnoising_data_parameters.npz"
divnoising_model_trained = False

In [None]:
factor = 0       # std of the synthetic Gaussian noise added to the images (0 -> clean)
n_images = 1000  # number of (noisy) training images kept for the experiment

### import data ###
# IDX image file: big-endian header (magic, count, rows, cols) followed by uint8 pixels.
with open(data_path + "raw/train-images-idx3-ubyte", 'rb') as f:
    magic, num, rows, cols = struct.unpack(">IIII", f.read(16))
    if magic != 2051:  # 0x00000803: IDX, unsigned byte, 3 dimensions
        raise ValueError(f"not an IDX3 image file (magic number {magic})")
    signal = np.fromfile(f, dtype = np.uint8).reshape(num, rows, cols).astype("float32") / 255

# Corrupt only the images that are actually kept; drawing noise for the whole
# file and slicing afterwards allocated a float64 array the size of the dataset.
clean_subset = signal[:n_images]
noisy = clean_subset + factor * np.random.normal(loc = 0.0, scale = 1.0, size = clean_subset.shape)
noisy = np.clip(noisy, 0., 1.)

# Train Adversarial DIVNOISING Model

This section trains the Adversarial DIVNOISING model on a small data set: the first 1000 MNIST training images.

In [None]:
### configure training parameters ###
patch_size = 28
train_fraction = 0.85
batch_size = 32
epochs = 100
learning_rate = 0.0005
kl_limit = 1e-5
gaussian_std = factor  # same noise std that was used to corrupt the images
noise_model = None     # NOTE(review): presumably "no learned noise model"; confirm in train.train


### preprocess data ###
# The same tensor is used as input and target: the model learns to reconstruct
# its own (noisy) input, so no clean targets are required.
train_tensor, val_tensor, mean, std = preprocess(noisy, patch_size, train_fraction)
train_loader = Data.DataLoader(dataset = Data.TensorDataset(train_tensor, train_tensor),
                               batch_size = batch_size, shuffle = True)
# Validation needs no shuffling: evaluating in a fixed order is cheaper and reproducible.
val_loader = Data.DataLoader(dataset = Data.TensorDataset(val_tensor, val_tensor),
                             batch_size = batch_size, shuffle = False)

### training ###
model = DIVNOISING(mean, std).cuda()
discriminator = Discriminator().cuda()
# Returned histories are plotted against epochs below (presumably one value per epoch).
recon_loss, d_loss, g_loss, val_loss = train.train(model, discriminator, model_path, divnoising_model_name, loss_name,
                                                   mean, std, train_loader, val_loader, noise_model,
                                                   gaussian_std, epochs, batch_size, learning_rate, kl_limit)


### plot loss ###
plt.figure(figsize=(20, 5))
plt.subplot(1,3,1)
plt.plot(recon_loss, label='reconstruction')
plt.xlabel("epochs")
plt.ylabel("reconstruction loss")
plt.legend()

plt.subplot(1,3,2)
plt.plot(d_loss, label='discriminator')
plt.plot(g_loss, label='generator')
plt.xlabel("epochs")
plt.ylabel("adversarial loss")
plt.legend()

plt.subplot(1,3,3)
plt.plot(val_loss, label='val')
plt.xlabel("epochs")
plt.ylabel("val loss")
plt.legend()
plt.savefig(model_path + 'loss.jpg')
plt.show()

# Predict Noise-Free Images

### Model in Epoch 5

In [None]:
# NOTE(review): the heading says "Epoch 5" but this loads the epoch-100
# checkpoint; confirm which checkpoint was intended.
model = torch.load(model_path + "100divnoising_model_last.net")

### configure parameters###
plot = True
num_samples = 100
num_display = 3
image_size = 28

### predict ###
# The clean reference must be the same image as the noisy input: noisy[i] is
# signal[i] plus noise, so index 4 is paired with index 4 (it was signal[2]).
predict(noisy[4:5], signal[4], model, image_size, num_samples, num_display, plot = plot)

### Model in Epoch 10

In [None]:
# Checkpoint written after 10 training epochs.
model = torch.load(model_path + "10divnoising_model_last.net")

### configure parameters###
num_samples, num_display = 100, 3
image_size = 28  # MNIST digit side length
plot = True

### predict ###
# First noisy image (sliced so it keeps its batch axis) against its clean counterpart.
predict(noisy[0:1], signal[0], model, image_size, num_samples, num_display, plot = plot)

### Model in Epoch 20

In [None]:
# Checkpoint written after 20 training epochs.
model = torch.load(model_path + "20divnoising_model_last.net")

### configure parameters###
num_samples, num_display = 100, 3
image_size = 28  # MNIST digit side length
plot = True

### predict ###
# First noisy image (sliced so it keeps its batch axis) against its clean counterpart.
predict(noisy[0:1], signal[0], model, image_size, num_samples, num_display, plot = plot)

### Model in Epoch 50

In [None]:
model = torch.load(model_path + "50divnoising_model_last.net")

# Reuse the MNIST arrays loaded at the top of the notebook. The former
# `imread(data_path + data_name)` reload referenced an undefined `data_name`
# and belonged to a 256x256 dataset.

### configure parameters###
plot = True
num_samples = 100
num_display = 3
image_size = 28  # MNIST digits are 28x28

### predict ###
# First noisy image (sliced so it keeps its batch axis) against its clean counterpart.
predict(noisy[0:1], signal[0], model, image_size, num_samples, num_display, plot = plot)

### Model in Epoch 75

In [None]:
model = torch.load(model_path + "75divnoising_model_last.net")

# Reuse the MNIST arrays loaded at the top of the notebook. The former
# `imread(data_path + data_name)` reload referenced an undefined `data_name`
# and belonged to a 256x256 dataset.

### configure parameters###
plot = True
num_samples = 100
num_display = 3
image_size = 28  # MNIST digits are 28x28

### predict ###
# First noisy image (sliced so it keeps its batch axis) against its clean counterpart.
predict(noisy[0:1], signal[0], model, image_size, num_samples, num_display, plot = plot)

### Model in Epoch 100

In [None]:
model = torch.load(model_path + "100divnoising_model_last.net")

# Reuse the MNIST arrays loaded at the top of the notebook. The former
# `imread(data_path + data_name)` reload referenced an undefined `data_name`
# and belonged to a 256x256 dataset.

### configure parameters###
plot = True
num_samples = 100
num_display = 3
image_size = 28  # MNIST digits are 28x28

### predict ###
# First noisy image (sliced so it keeps its batch axis) against its clean counterpart.
predict(noisy[0:1], signal[0], model, image_size, num_samples, num_display, plot = plot)

### Model in Epoch 200

In [None]:
# NOTE(review): the training cell runs `epochs = 100`, so this checkpoint only
# exists if it came from a longer run — confirm before executing.
model = torch.load(model_path + "200divnoising_model_last.net")

# Reuse the MNIST arrays loaded at the top of the notebook. The former
# `imread(data_path + data_name)` reload referenced an undefined `data_name`
# and belonged to a 256x256 dataset.

### configure parameters###
plot = True
num_samples = 100
num_display = 3
image_size = 28  # MNIST digits are 28x28

### predict ###
# First noisy image (sliced so it keeps its batch axis) against its clean counterpart.
predict(noisy[0:1], signal[0], model, image_size, num_samples, num_display, plot = plot)

### Model in Epoch 300

In [None]:
# NOTE(review): the training cell runs `epochs = 100`, so this checkpoint only
# exists if it came from a longer run — confirm before executing.
model = torch.load(model_path + "300divnoising_model_last.net")

# Reuse the MNIST arrays loaded at the top of the notebook. The former
# `imread(data_path + data_name)` reload referenced an undefined `data_name`
# and belonged to a 256x256 dataset.

### configure parameters###
plot = True
num_samples = 100
num_display = 3
image_size = 28  # MNIST digits are 28x28

### predict ###
# First noisy image (sliced so it keeps its batch axis) against its clean counterpart.
predict(noisy[0:1], signal[0], model, image_size, num_samples, num_display, plot = plot)

# Generate Images

This section uses the trained Adversarial DIVNOISING model to randomly generate images.

In [None]:
### configure parameters ###
model = torch.load(model_path + "100divnoising_model_last.net")
image_size = 28  # MNIST digits are 28x28
# Colour range taken from the clean MNIST images loaded at the top of the
# notebook (the stale `imread(data_path + data_name)` reload used an undefined name).
# NOTE(review): if `decode` returns normalized values (the last cell rescales
# Decoder output by data_std/data_mean), this range may need adjusting.
vmin = np.percentile(signal[0], 0)
vmax = np.percentile(signal[0], 98)
num_samples = 10

# Derive the latent shape from one encoded image instead of hard-coding it: the
# old (1, 64, 6, 6) latent and 24x24 output did not match this 28x28 model.
with torch.no_grad():
    probe = torch.Tensor(noisy[0].reshape(1, 1, image_size, image_size)).cuda()
    mu, _ = model.encode(probe)

plt.figure(figsize=(20, 5))
### generate images ###
with torch.no_grad():
    for i in range(num_samples):
        # Zero tensors for both reparameterize arguments, shaped like the encoder output.
        z = model.reparameterize(torch.zeros_like(mu), torch.zeros_like(mu))
        x = model.decode(z)
        plt.subplot(1, num_samples, i + 1)
        plt.imshow(x.cpu().numpy().reshape(image_size, image_size), vmin = vmin, vmax = vmax, cmap = "magma")

plt.savefig(model_path + "generated_images.jpg")
plt.show()

# Plot the latent distribution

This section plots the latent distributions produced by the intermediate models (the checkpoints saved every 10 epochs).

In [None]:
# Encode the first noisy image with every 10-epoch checkpoint and histogram the
# sampled latent code, to see how the latent distribution evolves during training.
img = noisy[0].reshape(1, 1, noisy.shape[1], noisy.shape[2])
img_tensor = torch.Tensor(img).cuda()

plt.figure(figsize=(60,60))
for epoch in range(10, 101, 10):
    model = torch.load(model_path + str(epoch) + "divnoising_model_last.net")
    with torch.no_grad():
        # Named mu/var so the global data `mean` returned by preprocess is not overwritten.
        mu, var = model.encode(img_tensor)
        z = model.reparameterize(mu, var)
    plt.subplot(2, 5, epoch // 10)
    # Flatten whatever latent shape the model produces (was hard-coded to 1568).
    plt.hist(z.detach().cpu().numpy().ravel())
    plt.title("epoch " + str(epoch))

plt.savefig(model_path + "adversarial_distribution.jpg")
plt.show()

In [None]:
# Draw 10 images from the final (epoch-100) model by decoding latent samples.
model = torch.load(model_path + str(100) + "divnoising_model_last.net")

with torch.no_grad():
    for _ in range(10):
        # NOTE(review): zero tensors for both arguments; whether this yields N(0, I)
        # samples or the mean depends on how reparameterize reads its second argument.
        z = model.reparameterize(torch.zeros(1,32).cuda(), torch.zeros(1,32).cuda())
        # Undo the data normalisation stored on the model.
        x = model.Decoder(z) * model.data_std + model.data_mean
        plt.imshow(x.cpu().numpy().reshape(28, 28), cmap="gray")
        plt.show()