## Import relevant libraries

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import cv2
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torchsummary import summary
from PIL import Image

## Weighted MSE Loss

In [None]:
def loss_weighted_mse(y_pred, y, loss_weights):
    """Return the mean squared error with a weight applied along dim 1.

    Args:
        y_pred: 4-D prediction tensor.
        y: 4-D target tensor, broadcast-compatible with ``y_pred``.
        loss_weights: 1-D tensor whose length equals the size of dim 1 of the
            squared error; entry ``j`` scales every element at index ``j``
            of that dimension.

    Returns:
        Scalar tensor holding the mean of the weighted squared error over
        all elements.
    """
    y_diff_squared = (y - y_pred) ** 2
    # Use the weights supplied by the caller rather than a module-level
    # variable, so the function works with any weight vector.
    y_diff_squared_weighted = torch.einsum('ijkl,j->ijkl', y_diff_squared, loss_weights)
    return torch.mean(y_diff_squared_weighted)

## Generate weights for exponential MSE Loss

In [None]:
def generate_exp_weights(T, alpha, decay_length=20.0):
    """Return ``T`` exponentially increasing weights ending at 1.

    The weights are ``exp(-decay_length * s)`` for ``s`` evenly spaced on
    [0, 1], reversed so that the last entry is the largest (1.0).

    Args:
        T: Number of weights to generate.
        alpha: Unused; kept so existing callers do not break.
        decay_length: Decay rate of the exponential. Defaults to 20.0.

    Returns:
        1-D tensor of length ``T``.
    """
    positions = torch.linspace(0, 1, steps=T)
    weights_decreasing = torch.exp(-decay_length * positions)
    return torch.flip(weights_decreasing, dims=[0])

## Point Spread Function (PSF)

In [None]:
def generate_psf_kernel(sigma=1.0, psf_size=15):
    """Build a normalized 2-D Gaussian kernel.

    Args:
        sigma: Standard deviation of the Gaussian, in pixels.
        psf_size: Side length of the square kernel. The Gaussian is centered
            on index ``psf_size // 2`` along both axes.

    Returns:
        Tensor of shape ``(psf_size, psf_size)`` in the default floating
        dtype whose entries sum to 1.
    """
    psf_center = psf_size // 2

    # The exponent is evaluated in float64 and only then narrowed, which
    # mirrors doing the arithmetic on Python floats element by element.
    offsets = torch.arange(psf_size, dtype=torch.float64) - psf_center
    squared_radius = offsets.view(-1, 1) ** 2 + offsets.view(1, -1) ** 2
    exponent = (-squared_radius / (2 * sigma ** 2)).to(torch.get_default_dtype())

    psf_kernel = torch.exp(exponent)
    psf_kernel /= psf_kernel.sum()

    return psf_kernel

In [None]:
# Default kernel (sigma=1.0, 15x15); later cells pass it to the PINN model.
psf_kernel = generate_psf_kernel()

## Define CoordConv2D Layer

In [None]:
class CoordConv2DLayer(nn.Module):
    """Produce normalized x/y coordinate channels for an input's spatial size.

    The layer has no parameters. Only the height, width and device of the
    input are used; its values are ignored.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input_tensor):
        """Return the coordinate grids for ``input_tensor``.

        Args:
            input_tensor: 4-D tensor; the last two dimensions are taken as
                height and width.

        Returns:
            Tensor of shape ``(1, 2, height, width)``. Channel 0 varies along
            the width axis and channel 1 along the height axis, each running
            from 0 to 1.
        """
        _, _, height, width = input_tensor.size()
        device = input_tensor.device

        # A single-pixel axis would otherwise divide 0 by 0 and yield NaN.
        x_scale = max(width - 1, 1)
        y_scale = max(height - 1, 1)

        # Create x and y coordinate grids on the same device as the input
        xx_channel = torch.arange(width, device=device).view(1, 1, 1, width).expand(1, 1, height, width).float() / x_scale
        yy_channel = torch.arange(height, device=device).view(1, 1, height, 1).expand(1, 1, height, width).float() / y_scale

        # Stack the two coordinate channels; the input itself is not included
        output_tensor = torch.cat([xx_channel, yy_channel], dim=1)
        return output_tensor

## Define FourierConv2D Layer

In [None]:
class FourierConv2DLayer(nn.Module):
    """Encode an input with sine and cosine features at ``L`` frequencies.

    The frequencies are ``2**j * pi`` for ``j`` in ``0 .. L-1``. The layer has
    no trainable parameters.
    """

    def __init__(self, L):
        super().__init__()
        self.L = L

    def forward(self, x):
        """Return the Fourier features of ``x`` concatenated along dim 1.

        Args:
            x: Input tensor with at least two dimensions.

        Returns:
            Tensor whose dim 1 is ``2 * L`` times that of ``x``: all sine
            blocks first (lowest frequency first), then all cosine blocks in
            the same order.
        """
        # Generate frequencies from the layer's own L, not a global
        base_frequency = 2
        exponent_value = torch.arange(self.L, device=x.device)
        frequencies = torch.pow(torch.tensor(base_frequency, device=x.device), exponent_value).float()

        # Each phase is computed once and shared by the sine and cosine terms
        phases = [frequencies[j] * torch.pi * x for j in range(self.L)]

        # Apply Fourier basis functions
        fourier_features = [torch.sin(phase) for phase in phases]
        fourier_features += [torch.cos(phase) for phase in phases]

        # Concatenate the Fourier features along the channel dimension
        return torch.cat(fourier_features, dim=1)

## Define InverseConv2D Layer

In [None]:
class InverseConv2DLayer(nn.Module):
    """Coordinate-based convolutional network that outputs a bounded map.

    The input is upsampled, replaced by coordinate channels (optionally
    Fourier-encoded), passed through a stack of 3x3 convolutions, squashed to
    a bounded range, and max-pooled back down by the same factor.

    Args:
        L: Number of Fourier frequencies.
        num_standard_layers: Number of hidden 3x3 convolution layers.
        max_reflectance: Scale applied to the sigmoid output.
        subsample_factor: Upsampling factor at the input and pooling factor
            at the output.
        use_fourier: If True, Fourier-encode the coordinate channels before
            the convolutions.
    """

    def __init__(self, L, num_standard_layers, max_reflectance, subsample_factor=2, use_fourier=True):
        super().__init__()
        self.L = L
        self.num_standard_layers = num_standard_layers
        # Stored on the instance so forward() does not depend on a global.
        self.max_reflectance = max_reflectance
        self.use_fourier = use_fourier

        self.upsample_layer = nn.Upsample(scale_factor=subsample_factor, mode='nearest')
        self.coordinate_layer = CoordConv2DLayer()
        # Always constructed (it has no parameters); only used when
        # use_fourier is True.
        self.fourier_layer = FourierConv2DLayer(L)

        # Two coordinate channels; the Fourier encoding turns each into a
        # sine and a cosine channel per frequency, i.e. 2 * 2 * L.
        self.num_fourier_channels = 4 * L if self.use_fourier else 2

        self.standard_hidden_layers = nn.ModuleList(
            [
                nn.Conv2d(self.num_fourier_channels, self.num_fourier_channels, kernel_size=3, padding='same')
                for _ in range(num_standard_layers)
            ]
        )

        self.standard_output_layer = nn.Conv2d(self.num_fourier_channels, 1, kernel_size=3, padding='same')
        self.downsample_layer = nn.MaxPool2d(kernel_size=subsample_factor, stride=subsample_factor)

        # Initialize weights with He uniform variance scaling initializer
        for layer in self.standard_hidden_layers:
            nn.init.kaiming_uniform_(layer.weight, mode='fan_in', nonlinearity='leaky_relu')
            nn.init.zeros_(layer.bias)  # Initialize biases to zero
        nn.init.kaiming_uniform_(self.standard_output_layer.weight, mode='fan_in', nonlinearity='leaky_relu')
        nn.init.zeros_(self.standard_output_layer.bias)

    def forward(self, x):
        """Run the network.

        Args:
            x: 4-D input tensor; only its spatial size (after upsampling) is
                used, because the coordinate layer discards the values.

        Returns:
            Tuple ``(output_coordinate, output_fourier, output,
            output_inverse)``: the coordinate channels, the encoded features
            fed to the convolutions, the full-resolution bounded output, and
            that output after max-pooling.
        """
        x = self.upsample_layer(x)
        output_coordinate = self.coordinate_layer(x)

        if self.use_fourier:
            output_fourier = self.fourier_layer(output_coordinate)
        else:
            output_fourier = output_coordinate

        x = output_fourier
        for layer in self.standard_hidden_layers:
            x = nn.functional.elu(layer(x))

        output = nn.functional.softplus(self.standard_output_layer(x))

        # NOTE(review): softplus is positive, so the sigmoid below stays above
        # 0.5 and the result lies in (max_reflectance / 2, max_reflectance).
        # Confirm that lower bound is intended.
        use_sigmoidal_output = True
        if use_sigmoidal_output:
            output = self.max_reflectance * torch.sigmoid(output)

        output_downsampled = self.downsample_layer(output)
        output_inverse = output_downsampled

        return output_coordinate, output_fourier, output, output_inverse

## Define Augmentation function

In [None]:
def augment_image(original_image, augmentation_stride=1, blob_intensity=0.01, contrast_steps=2, mode="contrast"):
    """Return an augmented copy of a 4-D image tensor.

    Args:
        original_image: 4-D tensor; the last two dimensions are indexed as
            rows and columns.
        augmentation_stride: Strength/index of the augmentation. Its meaning
            depends on ``mode`` (see below).
        blob_intensity: Value added inside the blob region (``"blob"`` only).
        contrast_steps: Number of levels between 0 and the image maximum
            (``"contrast"`` only).
        mode: One of:
            ``"translation"`` - shift right along the last axis by
            ``min(augmentation_stride, 19)`` columns, zero-filling the left;
            ``"elastic"`` - torchvision ``ElasticTransform`` with
            ``alpha = 50 + 10 * augmentation_stride``;
            ``"blob"`` - add ``blob_intensity`` to a 5x5 square whose top-left
            corner is at ``(augmentation_stride, augmentation_stride)``;
            ``"contrast"`` - ``abs(image - level)`` where ``level`` is entry
            ``augmentation_stride - 1`` of ``contrast_steps`` evenly spaced
            values from 0 to the image maximum.

    Returns:
        Tensor with the same shape as ``original_image``.

    Raises:
        ValueError: If ``mode`` is not one of the supported names.
    """
    if mode == "translation":
        augmentation_stride = min(augmentation_stride, 19)
        if augmentation_stride == 0:
            # A zero shift is the identity; slicing with [:-0] would select
            # nothing and fail on the assignment below.
            return original_image.clone()
        augmented_image = torch.zeros_like(original_image)
        augmented_image[:, :, :, augmentation_stride:] = original_image[:, :, :, :-augmentation_stride]
        return augmented_image

    if mode == "elastic":
        elastic_parameter = 50.0 + augmentation_stride * 10.0
        elastic_transformer = transforms.ElasticTransform(alpha=elastic_parameter)
        return elastic_transformer(original_image)

    if mode == "blob":
        blob_start = augmentation_stride
        blob_size = 5
        blob_end = blob_start + blob_size
        # zeros_like keeps the blob on the input's dtype and device.
        blob_tensor = torch.zeros_like(original_image)
        blob_tensor[:, :, blob_start:blob_end, blob_start:blob_end] = blob_intensity
        return original_image + blob_tensor

    if mode == "contrast":
        # .item() detaches and works for tensors on any device.
        reflectance_max = torch.max(original_image).item()
        reflectance_axis = torch.linspace(0, reflectance_max, steps=contrast_steps)
        reflectance_augmented = reflectance_axis[augmentation_stride - 1]
        return torch.abs(original_image - reflectance_augmented)

    # An unrecognized mode used to return an all-zero image silently.
    raise ValueError(f"Unknown augmentation mode: {mode!r}")

## Define AugmentationConv2D Layer

In [None]:
class AugmentationConv2DLayer(nn.Module):
    """Stack ``T`` augmented variants of an input along the channel axis."""

    def __init__(self, T):
        super().__init__()
        self.T = T

    def forward(self, x, alpha):
        """Return ``T`` augmentations of ``x`` concatenated on dim 1.

        Variant ``j`` (1-based) is ``augment_image`` called with
        ``augmentation_stride=j``, ``blob_intensity=alpha`` and
        ``contrast_steps=T``; the augmentation mode is left at its default.
        """
        # Unpacking exactly four sizes rejects inputs that are not 4-D.
        _batch, _channels, _rows, _cols = x.size()

        variants = []
        for stride in range(1, self.T + 1):
            variant = augment_image(
                x,
                augmentation_stride=stride,
                blob_intensity=alpha,
                contrast_steps=self.T,
            )
            variants.append(variant)

        # Join the variants on the channel axis
        return torch.cat(variants, dim=1)

## Define Microscope CNN Layer

In [None]:
class MicroscopeCNNLayer(nn.Module):
    """Fixed-kernel convolution followed by squared magnitude and peak scaling.

    Args:
        psf_kernel: Square 2-D tensor used as the (non-trainable) weight of a
            single-channel ``padding='same'`` convolution.
    """

    def __init__(self, psf_kernel):
        super().__init__()
        self.conv_layer = nn.Conv2d(1, 1, kernel_size=psf_kernel.size(0), padding='same', bias=False)
        # Freeze the kernel: it is given, not a quantity to be learned.
        self.conv_layer.weight = nn.Parameter(psf_kernel.unsqueeze(0).unsqueeze(0), requires_grad=False)
        self.intensity_layer = IntensityLayer()

    def forward(self, x):
        """Apply the convolution, intensity and normalization steps.

        Args:
            x: Tensor with a single channel on dim 1.

        Returns:
            Tuple ``(output_conv, output_intensity, output_final)`` where
            ``output_final`` is ``output_intensity`` divided by its global
            maximum (taken over the whole tensor, batch included).
        """
        output_conv = self.conv_layer(x)
        output_intensity = self.intensity_layer(output_conv)

        # An all-zero intensity would give 0 / 0 = NaN; flooring the divisor
        # at the smallest normal float maps that case to zeros instead and
        # leaves every other input unchanged.
        peak = torch.max(output_intensity)
        smallest = torch.finfo(output_intensity.dtype).tiny
        output_final = output_intensity / torch.clamp(peak, min=smallest)
        return output_conv, output_intensity, output_final

In [None]:
class IntensityLayer(nn.Module):
    """Return the element-wise squared magnitude of the input."""

    def forward(self, x):
        magnitude = x.abs()
        return magnitude.square()

## Define PINN

In [None]:
class PINN(nn.Module):
    """Inverse network followed by augmentation and the microscope model.

    Args:
        L: Number of Fourier frequencies for the inverse layer.
        T: Number of augmentations produced per forward pass.
        alpha: Value passed to the augmentation layer (used there as the
            blob intensity).
        num_standard_layers: Number of hidden convolutions in the inverse
            layer.
        max_reflectance: Upper bound of the inverse layer's output.
        subsample_factor: Up/down-sampling factor of the inverse layer.
        psf_kernel: Kernel for the microscope convolution.
        use_fourier: Whether the inverse layer Fourier-encodes coordinates.
    """

    def __init__(self, L, T, alpha, num_standard_layers, max_reflectance, subsample_factor, psf_kernel, use_fourier=True):
        super().__init__()
        # Stored on the instance so forward() does not depend on a global.
        self.alpha = alpha
        self.inverse_layer = InverseConv2DLayer(L, num_standard_layers, max_reflectance, subsample_factor, use_fourier=use_fourier)
        self.augmentation_layer = AugmentationConv2DLayer(T)
        self.forward_layer = MicroscopeCNNLayer(psf_kernel)

    def forward(self, x):
        """Predict the full-resolution map and the simulated images.

        Args:
            x: 4-D input tensor handed to the inverse layer.

        Returns:
            Tuple ``(output, output_final)``: the inverse layer's
            full-resolution output, and the normalized microscope output for
            each of the ``T`` augmentations stacked on dim 1.
        """
        _, _, output, output_inverse = self.inverse_layer(x)
        output_augmentation = self.augmentation_layer(output_inverse, self.alpha)

        simulated_channels = []
        for t in range(self.augmentation_layer.T):
            # Slice with t:t+1 to keep the channel axis, so the batch axis is
            # never mistaken for channels by the single-channel convolution.
            augmented_channel = output_augmentation[:, t:t + 1, :, :]
            _, _, output_final_t = self.forward_layer(augmented_channel)
            simulated_channels.append(output_final_t)

        output_final = torch.cat(simulated_channels, dim=1)

        return output, output_final

## Sensor function

In [None]:
def sensor_func(image, noise_level=0.1, subsample_factor=2):
    """Simulate a sensor: add Gaussian noise, max-pool, and clip to [0, 1].

    Args:
        image: Input tensor accepted by ``max_pool2d``.
        noise_level: Multiplier applied to standard-normal noise.
        subsample_factor: Kernel size and stride of the max pooling.

    Returns:
        The noisy, subsampled image with values clamped to [0, 1].
    """
    # Perturb every pixel with scaled standard-normal noise
    perturbed = image + torch.randn_like(image) * noise_level

    # Subsample by taking the maximum over non-overlapping windows
    pooled = nn.functional.max_pool2d(
        perturbed,
        kernel_size=subsample_factor,
        stride=subsample_factor,
    )

    # Keep the result inside the valid range [0, 1]
    return pooled.clamp(0, 1)

## Extract sample image

### MNIST

In [None]:
#mnist_train = datasets.MNIST(root="./data", train=True, download=True, transform=transforms.ToTensor())
#sample_image, label = mnist_train[0]
#sample_image, label = mnist_train[3]

### Shepp-Logan phantom or wavy fibers

In [None]:
# Convert a PIL image to a float tensor.
transform = transforms.Compose([transforms.ToTensor()])

sample_image_path = "shepp_logan_phantom.png"
#sample_image_path = "shepp_logan_phantom_complement.png"
#sample_image_path = "wavy_fibers_processed.png"

# Open the file in a context manager so its handle is released as soon as
# the pixel data has been copied into the tensor.
with Image.open(sample_image_path) as sample_image_pil:
    sample_image_not_normalized = transform(sample_image_pil)

# Rescale so the brightest pixel equals 1.
sample_image = sample_image_not_normalized / torch.max(sample_image_not_normalized)

## Visualize sample image

In [None]:
# Drop singleton dimensions so imshow receives a 2-D array.
plt.imshow(sample_image.squeeze(), cmap='gray')

## Construct Ground Truth Reflectance

In [None]:
# The peak-normalized sample image serves as the intensity map.
normalized_pixel_intensity = sample_image
# Scale factor of the ground truth; later cells also reuse it as the
# augmentation argument and as max_reflectance.
alpha = 0.01
reflectance_ground_truth = alpha*normalized_pixel_intensity

## Resize Ground Truth Reflectance

In [None]:
# Add a leading batch dimension of size 1: (C, H, W) -> (1, C, H, W).
channels, height, width = reflectance_ground_truth.size()
reflectance_ground_truth = reflectance_ground_truth.view(1, channels, height, width)
# Display the resulting shape as the cell output.
reflectance_ground_truth.shape

## Generate "Training" Images

In [None]:
# Optionally simulate the data with a different PSF than the one the PINN
# is given, to emulate model mismatch.
apply_model_mismatch = True
if apply_model_mismatch:
    psf_kernel_real = generate_psf_kernel(sigma=2.0, psf_size=21)
    microscope_model = MicroscopeCNNLayer(psf_kernel_real)
else:
    microscope_model = MicroscopeCNNLayer(psf_kernel)

subsample_factor = 16
apply_sensor = True

# Number of augmentations (20 is the alternative setting).
T = 2
augmentation_layer = AugmentationConv2DLayer(T)
input_augmentation = augmentation_layer(reflectance_ground_truth, alpha)

# These sizes are not needed below; the names are kept defined in case other
# cells read them.
batch_size, num_augmentations, height, width = input_augmentation.size()

use_negative = False

if use_negative:
    # Simulate the un-augmented image once; the loop below derives every
    # channel from it. (The loop previously read `microscope_image`, which
    # was never assigned.)
    _, _, microscope_image = microscope_model(reflectance_ground_truth)
    intensity_max = torch.max(microscope_image).item()
    intensity_axis = torch.linspace(0, intensity_max, steps=T)

# Collect the channels in a list and concatenate afterwards, so the result
# has the right spatial size whether or not the sensor subsamples.
training_frames = []
for t in range(T):
    if use_negative:
        intensity_augmented = intensity_axis[t]
        training_image_t = torch.abs(microscope_image - intensity_augmented)
    else:
        # t:t+1 keeps the channel axis expected by the convolution.
        _, _, training_image_t = microscope_model(input_augmentation[:, t:t + 1, :, :])
    # NOTE(review): the sensor is now applied on the use_negative path too,
    # so both paths yield the same resolution - confirm this is desired.
    if apply_sensor:
        training_image_t = sensor_func(training_image_t, subsample_factor=subsample_factor)
    training_frames.append(training_image_t)

training_image = torch.cat(training_frames, dim=1)

## Check Training Image shape

In [None]:
# Display the shape of the simulated training stack as the cell output.
training_image.shape

## Visualize Training Images

In [None]:
# One panel per augmentation channel.
fig, axs = plt.subplots(1, T, figsize=(15, 3))
# With T == 1, subplots returns a bare Axes instead of an array; wrap it so
# the indexing below works for every T.
axs = np.atleast_1d(axs)
for t in range(T):
    axs[t].imshow(training_image[:,t,:,:].squeeze().detach().numpy(), cmap='gray')

In [None]:
# Show the first training channel on its own.
plt.imshow(training_image[:,0,:,:].squeeze().detach().numpy(), cmap='gray')

## Define configurations for Inverse Layer

In [None]:
# Number of Fourier frequencies used by the encoding layer.
L = 4
# Number of hidden 3x3 convolution layers in the inverse network.
num_standard_layers = 4
# Upper bound of the predicted reflectance; tied to the ground-truth scale.
max_reflectance = alpha
# Fourier-encode the coordinate channels before the convolutions.
use_fourier = True

## Network training

### Model summary

In [None]:
# Display the input size that the model summary below is built from.
training_image.size()

In [None]:
# Separate PINN instance used only to print a layer summary.
pinn_model_dummy = PINN(L, T, alpha, num_standard_layers, max_reflectance, subsample_factor, psf_kernel, use_fourier)
# Pass the per-sample shape, i.e. the training stack's size without its
# leading batch dimension.
_, channels_dummy, height_dummy, width_dummy = training_image.size()
summary(pinn_model_dummy, input_size=(channels_dummy, height_dummy, width_dummy))

### Model training

In [None]:
# Initialize PINN model
pinn_model = PINN(L, T, alpha, num_standard_layers, max_reflectance, subsample_factor, psf_kernel, use_fourier)

# Number of training epochs
num_epochs = 10000

# Learning rate at the first step and the value it decays to after
# num_epochs scheduler steps
initial_learning_rate = 1e-4
final_learning_rate = 5e-6

# Define loss function and optimizer
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(pinn_model.parameters(), lr=initial_learning_rate)

# Set up the exponential learning rate scheduler. gamma is derived from
# num_epochs so the decay always spans the whole run, instead of a separately
# hard-coded step count that could drift out of sync.
scheduler = torch.optim.lr_scheduler.ExponentialLR(
    optimizer,
    gamma=np.exp(np.log(final_learning_rate / initial_learning_rate) / num_epochs),
)

# Set the regularization strengths (lambdas)
lambda_boundary = 0.25

# Define list to store loss values
loss_list = []

# Define list to store intermediate outputs for reflectance
intermediate_output_reflectance = []

# Generate exp mse loss weights
use_exp_mse_weights = True
if use_exp_mse_weights:
    exp_mse_weights = generate_exp_weights(T, alpha)

# Training loop
for epoch in range(num_epochs):
    # Forward pass
    iter_reflectance, iter_image = pinn_model(training_image)

    # Calculate the mse loss
    if use_exp_mse_weights:
        loss_mse = loss_weighted_mse(iter_image, training_image, exp_mse_weights)
    else:
        loss_mse = criterion(iter_image, training_image)

    # Calculate the boundary regularization loss: penalize non-zero values
    # on the four edges of every simulated image
    loss_boundary = lambda_boundary * (
          torch.square(iter_image[:, :, 0, :] - 0).mean()
        + torch.square(iter_image[:, :, -1, :] - 0).mean()
        + torch.square(iter_image[:, :, :, 0] - 0).mean()
        + torch.square(iter_image[:, :, :, -1] - 0).mean()
    )

    # Calculate the total loss
    loss_total = loss_mse + loss_boundary
    loss = loss_total

    # Backward pass and optimization
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Update the learning rate
    scheduler.step()

    # Print training statistics
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item()}')
    loss_list.append(loss.item())

## Display Loss history

In [None]:
# Plot the total loss recorded once per training epoch.
plt.figure()
plt.plot(loss_list)
plt.xlabel('Iteration')
plt.ylabel('Training Loss')
plt.title('Training Loss History')

## Generate predictions

In [None]:
# Inference only: skip autograd bookkeeping so no graph is built or retained.
with torch.no_grad():
    # First output of the PINN is the full-resolution reflectance estimate.
    predicted_reflectance, _ = pinn_model(training_image)
    # Pass the estimate through the data-generating microscope model.
    _, _, predicted_image = microscope_model(predicted_reflectance)

## Display results

In [None]:
# Convert the tensors to 2-D NumPy arrays for plotting.
ground_truth_plot = reflectance_ground_truth.squeeze().numpy()
predicted_reflectance_plot = predicted_reflectance.squeeze().detach().numpy()

# Middle panel: either the microscope image or the first training channel.
if use_negative:
    microscope_output_plot = microscope_image.squeeze().detach().numpy()
else:
    unaugmented_output_plot = training_image[:,0,:,:].squeeze().detach().numpy()

# Three panels side by side: ground truth, measurement, prediction.
fig, (ax_1, ax_2, ax_3) = plt.subplots(1, 3, figsize=(14, 6))
fig.suptitle("Summary")

ax_1.imshow(ground_truth_plot, cmap='gray')
ax_1.set_ylabel('Y-coordinate')
ax_1.set_title('Ground Truth')

if use_negative:
    ax_2.imshow(microscope_output_plot, cmap='gray')
    ax_2.set_title('Microscope Output')
else:
    ax_2.imshow(unaugmented_output_plot, cmap='gray')
    ax_2.set_title('Unaugmented Output')
ax_2.set_xlabel('X-coordinate')

ax_3.imshow(predicted_reflectance_plot, cmap='gray')
ax_3.set_title('PINN Prediction')

plt.show()

In [None]:
# Ground truth and predicted reflectance side by side.
fig, (ax_1, ax_2) = plt.subplots(1, 2, figsize=(10, 6))
fig.suptitle("Visualize Reflectance")

ax_1.imshow(ground_truth_plot, cmap='gray')
ax_1.set_xlabel('X-coordinate')
ax_1.set_ylabel('Y-coordinate')
ax_1.set_title('Ground Truth')

ax_2.imshow(predicted_reflectance_plot, cmap='gray')
ax_2.set_title('PINN Prediction')

In [None]:
# Display the shape of the re-simulated image as the cell output.
predicted_image.shape

In [None]:
# Display the shape of the predicted reflectance as the cell output.
predicted_reflectance.shape

In [None]:
# Show the microscope image simulated from the predicted reflectance.
plt.imshow(predicted_image.squeeze().detach().numpy(), cmap='gray')

In [None]:
# Row/column index grids matching the reflectance array. indexing='ij' is
# stated explicitly: it is the layout the call produced implicitly, and
# omitting it raises a deprecation warning in current PyTorch.
x, y = torch.meshgrid(
    torch.arange(predicted_reflectance_plot.shape[0]),
    torch.arange(predicted_reflectance_plot.shape[1]),
    indexing='ij',
)

fig= plt.figure(figsize=(10, 6))
fig.suptitle("Surface plot of the Predicted Reflectance")

ax_1 = fig.add_subplot(111, projection='3d')
surface = ax_1.plot_surface(x, y, predicted_reflectance_plot, cmap='viridis')
plt.xlabel('X-coordinate')
plt.ylabel('Y-coordinate')

plt.show()

In [None]:
# 3-D surface of the ground-truth reflectance over the index grids x, y.
fig= plt.figure(figsize=(10, 6))
fig.suptitle("Surface plot of the Ground Truth")

ax_1 = fig.add_subplot(111, projection='3d')
surface = ax_1.plot_surface(x, y, ground_truth_plot, cmap='viridis')
plt.xlabel('X-coordinate')
plt.ylabel('Y-coordinate')

plt.show()

In [None]:
# Signed per-pixel error: prediction minus ground truth.
error_plot = predicted_reflectance_plot - ground_truth_plot

In [None]:
# 3-D surface of the signed error over the index grids x, y.
fig= plt.figure(figsize=(10, 6))
fig.suptitle("Surface plot of the Error")

ax_1 = fig.add_subplot(111, projection='3d')
surface = ax_1.plot_surface(x, y, error_plot, cmap='viridis')
plt.xlabel('X-coordinate')
plt.ylabel('Y-coordinate')

plt.show()

In [None]:
# Flatten first: given a 2-D array, plt.hist treats every column as a
# separate dataset and draws one histogram per column rather than a single
# distribution over all pixels.
plt.hist(error_plot.ravel())
plt.xlabel('error value')
plt.ylabel('frequency')
plt.title("Error Histogram")
plt.show()

In [None]:
# Mean squared error between predicted and ground-truth reflectance.
np.square(error_plot).mean()