# Import Dependencies

In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np
import torch.utils.data as Data 
from tifffile import imread
import os
import urllib
import zipfile

from noise import GaussianMixtureModel
from data import preprocess
from model import DIVNOISING, Discriminator
from predict import predict
import train

# Configure Global Parameters

In [None]:
# --- input data ---------------------------------------------------------
data_path = "../Mouse skull nuclei/"
data_name = "example2_digital_offset300.tif"
calibration_data_name = "edgeoftheslide_300offset.tif"

# --- saved models / parameters (all relative to model_path) -------------
model_path = "./Mouse/"
noise_model_name = "noise_model.npz"
divnoising_model_name = "divnoising_model_last.net"
divnoising_data_parameters = "divnoising_data_parameters.npz"

# When True the noise-model cell loads saved weights from
# model_path + noise_model_name; when False it trains a new noise model.
noise_model_trained = True

# Download Data

In [None]:
# `import urllib` alone does not guarantee the `request` submodule is loaded.
import urllib.request

zipPath = "../data/Mouse_skull_nuclei.zip"
data_url = 'https://zenodo.org/record/5156960/files/Mouse%20skull%20nuclei.zip?download=1'

# urlretrieve does not create missing directories, so make sure ../data exists.
os.makedirs(os.path.dirname(zipPath), exist_ok=True)

if not os.path.exists(zipPath):
    # Download to a temporary name and rename only on success, so an interrupted
    # download does not leave a truncated zip that would be skipped next time.
    partial_path = zipPath + ".part"
    urllib.request.urlretrieve(data_url, partial_path)
    os.replace(partial_path, zipPath)

# Extract whenever the dataset folder is missing, not only right after a fresh
# download. NOTE(review): assumes the archive's top-level folder is the
# directory named by data_path once extracted into "../" — confirm.
if not os.path.isdir(data_path):
    with zipfile.ZipFile(zipPath, 'r') as zip_ref:
        zip_ref.extractall("../")

# Train Gaussian Mixture Noise Model

In [None]:
### import calibration data ###
noisy = imread(data_path + calibration_data_name)

### configure noise model parameters ###
num_of_gaussian = 3
num_of_coeff = 2
epochs = 2000
batch_size = 250000
min_variance = 50
weight = None  # passed as-is to GaussianMixtureModel unless saved weights are loaded below
learning_rate = 0.1

### prepare data ###
# Per-pixel average over all calibration frames, with a leading axis of length 1.
signal = np.mean(noisy, axis=0)[np.newaxis, ...]

### visualize the signal ###
plt.figure(figsize=(12, 12))
plt.subplot(1, 2, 2)
plt.title(label='average (ground truth)')
plt.imshow(signal[0], cmap='gray')
plt.subplot(1, 2, 1)
plt.title(label='single raw image')
plt.imshow(noisy[0], cmap='gray')
plt.show()

### train (or load) noise model ###
max_signal = np.max(signal)
min_signal = np.min(signal)
noise_model_file = model_path + noise_model_name
if noise_model_trained:
    # Use the NpzFile as a context manager so the archive's file handle is
    # closed; indexing it reads the array fully into memory first.
    with np.load(noise_model_file) as saved:
        weight = saved["weight"]
noise_model = GaussianMixtureModel(weight=weight,
                                   gaussian=num_of_gaussian, coeff=num_of_coeff,
                                   max_signal=max_signal, min_signal=min_signal,
                                   min_variance=min_variance)
if not noise_model_trained:
    noise_model.train(noisy, signal, learning_rate, batch_size, epochs, noise_model_file)

# Train Adversarial DIVNOISING Model

This section trains the Adversarial DIVNOISING model on a small dataset consisting of only 10 images.

In [None]:
### import data ###
# Small-data experiment: only the first 10 frames of the stack are used.
noisy = imread(data_path + data_name).astype("float32")[0:10]

### configure training parameters ###
patch_size = 128
train_fraction = 0.85
batch_size = 32
epochs = 100
learning_rate = 0.0002
kl_limit = 1e-5
gaussian_std = 25
# `loss_name` is passed to train.train below but was not defined anywhere in the
# notebook, so this cell raised NameError.
# NOTE(review): placeholder filename — confirm what train.train expects here.
loss_name = "divnoising_loss.npz"


### preprocess data ###
train_tensor, val_tensor, mean, std = preprocess(noisy, patch_size, train_fraction)
# Each dataset pairs the patches with themselves (input, target).
train_loader = Data.DataLoader(dataset=Data.TensorDataset(train_tensor, train_tensor),
                               batch_size=batch_size, shuffle=True)
# Validation does not need random ordering; a fixed order keeps validation batches deterministic.
val_loader = Data.DataLoader(dataset=Data.TensorDataset(val_tensor, val_tensor),
                             batch_size=batch_size, shuffle=False)

### training ###
model = DIVNOISING(mean, std).cuda()
discriminator = Discriminator().cuda()
recon_loss, d_loss, g_loss, val_loss = train.train(model, discriminator, model_path, divnoising_model_name, loss_name,
                                                   mean, std, train_loader, val_loader, noise_model,
                                                   gaussian_std, epochs, batch_size, learning_rate, kl_limit)


### plot loss ###
plt.figure(figsize=(20, 5))
plt.subplot(1, 3, 1)
plt.plot(recon_loss, label='reconstruction')
plt.xlabel("epochs")
plt.ylabel("reconstruction loss")
plt.legend()

plt.subplot(1, 3, 2)
plt.plot(d_loss, label='discriminator')
plt.plot(g_loss, label='generator')
plt.xlabel("epochs")
plt.ylabel("adversarial loss")
plt.legend()

plt.subplot(1, 3, 3)
plt.plot(val_loss, label='val')
plt.xlabel("epochs")
plt.ylabel("val loss")
plt.legend()
plt.savefig(model_path + 'loss.jpg')
plt.show()

# Predict Noise-Free Images

### Model in Epoch 5

In [None]:
model = torch.load(model_path + "5divnoising_model_last.net")

### import data ###
noisy = imread(data_path + data_name).astype("float32")
# Per-pixel average over all frames, used as the reference ("ground truth") image.
signal = noisy.mean(axis=0)

### configure parameters ###
plot = True
num_samples = 100
num_display = 3
image_size = 256

### predict ###
# Pass the configured `plot` flag instead of a hard-coded literal so toggling it above takes effect.
predict(noisy, signal, model, image_size, num_samples, num_display, plot=plot)

### Model in Epoch 10

In [None]:
model = torch.load(model_path + "10divnoising_model_last.net")

### import data ###
noisy = imread(data_path + data_name).astype("float32")
# Per-pixel average over all frames, used as the reference ("ground truth") image.
signal = noisy.mean(axis=0)

### configure parameters ###
plot = True
num_samples = 100
num_display = 3
image_size = 256

### predict ###
# Pass the configured `plot` flag instead of a hard-coded literal so toggling it above takes effect.
predict(noisy, signal, model, image_size, num_samples, num_display, plot=plot)

### Model in Epoch 20

In [None]:
model = torch.load(model_path + "20divnoising_model_last.net")

### import data ###
noisy = imread(data_path + data_name).astype("float32")
# Per-pixel average over all frames, used as the reference ("ground truth") image.
signal = noisy.mean(axis=0)

### configure parameters ###
plot = True
num_samples = 100
num_display = 3
image_size = 256

### predict ###
# Pass the configured `plot` flag instead of a hard-coded literal so toggling it above takes effect.
predict(noisy, signal, model, image_size, num_samples, num_display, plot=plot)

### Model in Epoch 50

In [None]:
model = torch.load(model_path + "50divnoising_model_last.net")

### import data ###
noisy = imread(data_path + data_name).astype("float32")
# Per-pixel average over all frames, used as the reference ("ground truth") image.
signal = noisy.mean(axis=0)

### configure parameters ###
plot = True
num_samples = 100
num_display = 3
image_size = 256

### predict ###
# Pass the configured `plot` flag instead of a hard-coded literal so toggling it above takes effect.
predict(noisy, signal, model, image_size, num_samples, num_display, plot=plot)

### Model in Epoch 75

In [None]:
model = torch.load(model_path + "75divnoising_model_last.net")

### import data ###
noisy = imread(data_path + data_name).astype("float32")
# Per-pixel average over all frames, used as the reference ("ground truth") image.
signal = noisy.mean(axis=0)

### configure parameters ###
plot = True
num_samples = 100
num_display = 3
image_size = 256

### predict ###
# Pass the configured `plot` flag instead of a hard-coded literal so toggling it above takes effect.
predict(noisy, signal, model, image_size, num_samples, num_display, plot=plot)

### Model in Epoch 100

In [None]:
model = torch.load(model_path + "100divnoising_model_last.net")

### import data ###
noisy = imread(data_path + data_name).astype("float32")
# Per-pixel average over all frames, used as the reference ("ground truth") image.
signal = noisy.mean(axis=0)

### configure parameters ###
plot = True
num_samples = 100
num_display = 3
image_size = 256

### predict ###
# NOTE(review): unlike the other checkpoint cells, this one predicts only on the
# first 100 frames — confirm whether that is intentional.
# Pass the configured `plot` flag instead of a hard-coded literal so toggling it above takes effect.
predict(noisy[0:100], signal, model, image_size, num_samples, num_display, plot=plot)

### Model in Epoch 200

In [None]:
model = torch.load(model_path + "200divnoising_model_last.net")

### import data ###
noisy = imread(data_path + data_name).astype("float32")
# Per-pixel average over all frames, used as the reference ("ground truth") image.
signal = noisy.mean(axis=0)

### configure parameters ###
plot = True
num_samples = 100
num_display = 3
image_size = 256

### predict ###
# Pass the configured `plot` flag instead of a hard-coded literal so toggling it above takes effect.
predict(noisy, signal, model, image_size, num_samples, num_display, plot=plot)

### Model in Epoch 300

In [None]:
model = torch.load(model_path + "300divnoising_model_last.net")

### import data ###
noisy = imread(data_path + data_name).astype("float32")
# Per-pixel average over all frames, used as the reference ("ground truth") image.
signal = noisy.mean(axis=0)

### configure parameters ###
plot = True
num_samples = 100
num_display = 3
image_size = 256

### predict ###
# Pass the configured `plot` flag instead of a hard-coded literal so toggling it above takes effect.
predict(noisy, signal, model, image_size, num_samples, num_display, plot=plot)

# Predict a Specific Crop

In [None]:
noisy = imread(data_path + data_name).astype("float32")
# Per-pixel average over all frames, used as the reference ("ground truth") image.
signal = noisy.mean(axis=0)

# Crop window: top-left corner (row, column) and square side length.
crop_top, crop_left, crop_size = 400, 200, 128
rows = slice(crop_top, crop_top + crop_size)
cols = slice(crop_left, crop_left + crop_size)

# First frame's crop with batch and channel axes added: (1, 1, crop_size, crop_size).
img = noisy[0, rows, cols].reshape(1, 1, crop_size, crop_size)
img_tensor = torch.Tensor(img).cuda()

# 3x4 grid: slot 1 = noisy crop, slots 2-11 = checkpoints at epochs 10..100,
# slot 12 = averaged reference. No plt.legend(): nothing is labeled.
plt.figure(figsize=(30, 30))
plt.subplot(3, 4, 1)
plt.imshow(img.reshape(crop_size, crop_size), cmap="magma")
plt.title("original")

# Inference only: skip autograd bookkeeping.
with torch.no_grad():
    for epoch in range(10, 101, 10):
        model = torch.load(model_path + str(epoch) + "divnoising_model_last.net")
        # Use distinct names so the notebook-global data `mean` from training is not overwritten.
        z_mean, z_var = model.encode(img_tensor)
        z = model.reparameterize(z_mean, z_var)
        x = model.decode(z)
        plt.subplot(3, 4, epoch // 10 + 1)
        plt.imshow(x.detach().cpu().numpy().reshape(crop_size, crop_size), cmap="magma")
        plt.title("epoch " + str(epoch))

plt.subplot(3, 4, 12)
plt.imshow(signal[rows, cols], cmap="magma")
plt.title("ground truth")
plt.savefig(model_path + "denoising.jpg")
plt.show()

# Plot Latent Distribution

This section plots the latent distributions produced by the intermediate (per-epoch) model checkpoints.

In [None]:
noisy = imread(data_path + data_name).astype("float32")
# First frame with batch and channel axes added: (H, W) -> (1, 1, H, W).
# Using both spatial dims (rather than shape[1] twice) keeps non-square frames valid.
img = noisy[0][np.newaxis, np.newaxis]
img_tensor = torch.Tensor(img).cuda()

# 2x5 grid: one histogram of sampled latent values per checkpoint (epochs 10..100).
plt.figure(figsize=(20, 8))
# Inference only: skip autograd bookkeeping.
with torch.no_grad():
    for epoch in range(10, 101, 10):
        model = torch.load(model_path + str(epoch) + "divnoising_model_last.net")
        # Use distinct names so the notebook-global data `mean` from training is not overwritten.
        z_mean, z_var = model.encode(img_tensor)
        z = model.reparameterize(z_mean, z_var)
        plt.subplot(2, 5, epoch // 10)
        # Flatten all latent values; ravel() works for any latent size instead of a hard-coded count.
        plt.hist(z.detach().cpu().numpy().ravel())
        plt.title("epoch " + str(epoch))

plt.savefig(model_path + "latent_distribution.jpg")
plt.show()

# Generate Images

In [None]:
# The epoch-100 checkpoint is the same for every sample, so load it once.
model = torch.load(model_path + str(100) + "divnoising_model_last.net")

# Latent tensor shape fed to reparameterize (batch, channels, height, width).
latent_shape = (1, 64, 4, 4)

# 2x5 grid of 10 generated images. No plt.legend(): nothing is labeled.
plt.figure(figsize=(20, 8))

# Inference only: skip autograd bookkeeping.
with torch.no_grad():
    for i in range(10):
        # Both distribution parameters are all-zero tensors.
        # NOTE(review): what this samples from depends on how reparameterize
        # interprets its second argument (variance vs. log-variance) — confirm.
        z = model.reparameterize(torch.zeros(latent_shape).cuda(), torch.zeros(latent_shape).cuda())
        # Undo the data normalization stored on the model.
        x = model.Decoder(z) * model.data_std + model.data_mean
        plt.subplot(2, 5, i + 1)
        plt.imshow(x.detach().cpu().numpy().reshape(16, 16), cmap="magma")

plt.savefig(model_path + "generation.jpg")
plt.show()