## Import relevant libraries

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torchsummary import summary
from PIL import Image

## Define image transform

In [None]:
# Single-step pipeline: convert a PIL image into a float tensor laid out as (C, H, W).
transform = transforms.Compose(transforms=[transforms.ToTensor()])

## Load sample image

In [None]:
# Select the input image (uncomment exactly one path).
#sample_image_path = "shepp_logan_phantom.png"
#sample_image_path = "shepp_logan_phantom_complement.png"
sample_image_path = "wavy_fibers_processed.png"
#sample_image_path = "wavy_fibers_processed_2.png"

# The forward model (MicroscopeCNNLayer) is a single-channel Conv2d, so the image must
# have exactly one channel. Colour, palette and alpha images are converted to 8-bit
# greyscale. Single-channel modes pass through unchanged. The context manager closes
# the file handle once ToTensor has read the pixel data.
with Image.open(sample_image_path) as sample_image_pil:
    if sample_image_pil.mode not in ("L", "I", "I;16", "F"):
        sample_image_pil = sample_image_pil.convert("L")
    sample_image = transform(sample_image_pil)

In [None]:
# Alternative input: use an MNIST digit instead of the PNG loaded above.
# Uncomment the dataset line and one of the sample lines; indices 0 and 3 pick different digits.
#mnist_train = datasets.MNIST(root="./data", train=True, download=True, transform=transforms.ToTensor())
#sample_image, label = mnist_train[0]
#sample_image, label = mnist_train[3]

In [None]:
# Inspect the tensor layout produced by ToTensor: (channels, height, width).
sample_image.size()

In [None]:
# Quick look at the loaded image, with singleton dimensions dropped for imshow.
plt.imshow(torch.squeeze(sample_image), cmap='gray')

## Normalize pixel intensity

In [None]:
# Rescale so that the brightest pixel has intensity exactly 1.
sample_image = sample_image / sample_image.max()

In [None]:
# Sanity check: should be 1 after normalisation.
sample_image.max()

In [None]:
# Smallest normalised intensity in the image.
sample_image.min()

## Surface plot of pixel intensity

In [None]:
sample_image_plot = sample_image.squeeze().numpy()

# Plotting grid in image convention. np.meshgrid's default 'xy' indexing returns
# arrays shaped like sample_image_plot (rows, cols), with xp[i, j] = j as the column (X)
# index and yp[i, j] = i as the row (Y) index. torch.meshgrid without `indexing=`
# uses 'ij', which puts the row index on the axis labelled X, and it emits a warning.
xp, yp = np.meshgrid(np.arange(sample_image_plot.shape[1]), np.arange(sample_image_plot.shape[0]))

fig = plt.figure(figsize=(10, 6))
ax = fig.add_subplot(111, projection='3d')
ax.plot_surface(xp, yp, sample_image_plot, cmap='viridis', alpha=0.4)

ax.set_xlabel('X-coordinate')
ax.set_ylabel('Y-coordinate')
ax.set_title('Pixel Intensity Map')
plt.show()

## Construct Ground Truth Reflectance

In [None]:
# Ground-truth reflectance is the [0, 1]-normalised image scaled by alpha, so it lies in [0, alpha].
alpha = 0.01
normalized_pixel_intensity = sample_image
reflectance_ground_truth = normalized_pixel_intensity * alpha

In [None]:
# Layout before adding the batch dimension.
reflectance_ground_truth.size()

## Add a batch dimension to the Ground Truth Reflectance

In [None]:
# Prepend a batch axis: (C, H, W) -> (1, C, H, W), as Conv2d-based layers expect.
channels, height, width = reflectance_ground_truth.shape
reflectance_ground_truth = reflectance_ground_truth.unsqueeze(0)
reflectance_ground_truth.size()


## Point Spread Function (PSF)

In [None]:
def generate_psf_kernel(sigma=1.0, psf_size=15):
    """Return a normalised, isotropic 2-D Gaussian point-spread-function kernel.

    Args:
        sigma: Standard deviation of the Gaussian, in pixels. Must be non-zero.
        psf_size: Side length of the square kernel. The Gaussian peak sits at index
            ``psf_size // 2`` along each axis, which is exactly centred for odd sizes.

    Returns:
        A ``(psf_size, psf_size)`` tensor of the default floating dtype whose
        entries sum to 1.

    Raises:
        ValueError: If ``sigma`` is zero.
    """
    if sigma == 0:
        raise ValueError("sigma must be non-zero")

    psf_center = psf_size // 2
    # Signed pixel offsets from the centre along one axis. The squared radius is the
    # outer sum of squared offsets, computed in one vectorised step instead of a
    # per-pixel Python loop.
    offsets = torch.arange(psf_size, dtype=torch.get_default_dtype()) - psf_center
    squared_radius = offsets[:, None] ** 2 + offsets[None, :] ** 2

    psf_kernel = torch.exp(-squared_radius / (2 * sigma ** 2))
    # Normalise to unit sum so that convolving with the kernel preserves total energy.
    psf_kernel /= psf_kernel.sum()

    return psf_kernel

In [None]:
# PSF assumed by the PINN's forward model (the function's default parameters, spelled out).
psf_kernel = generate_psf_kernel(sigma=1.0, psf_size=15)

## CoordConv2D Layer

In [None]:
class CoordConv2DLayer(nn.Module):
    """Produce a normalised (x, y) coordinate grid matching the input's spatial size.

    Despite the name, this layer has no convolution and ignores the input's values.
    Only the input's batch size, spatial size, device and (floating) dtype are used.
    The output has shape ``(N, 2, H, W)``:
    channel 0 is ``column / (W - 1)`` and channel 1 is ``row / (H - 1)``, both in [0, 1].
    A dimension of size 1 yields coordinate 0 rather than NaN.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input_tensor):
        batch_size, _, height, width = input_tensor.size()
        dtype = input_tensor.dtype if input_tensor.is_floating_point() else torch.float32
        device = input_tensor.device

        # max(..., 1) avoids 0 / 0 (NaN) when there is a single row or column.
        xs = torch.arange(width, device=device, dtype=dtype) / max(width - 1, 1)
        ys = torch.arange(height, device=device, dtype=dtype) / max(height - 1, 1)

        # One (broadcast) grid per batch element, on the same device as the input.
        xx_channel = xs.view(1, 1, 1, width).expand(batch_size, 1, height, width)
        yy_channel = ys.view(1, 1, height, 1).expand(batch_size, 1, height, width)

        # Only the two coordinate channels are returned; the input is not concatenated.
        return torch.cat([xx_channel, yy_channel], dim=1)

## FourierConv2D Layer

In [None]:
class FourierConv2DLayer(nn.Module):
    """Encode every input channel with sin/cos features at ``L`` octave-spaced frequencies.

    With frequencies ``f_j = 2**j`` for ``j = 0 .. L-1``, the output is the channel-wise
    concatenation
    ``[sin(f_0*pi*x), ..., sin(f_{L-1}*pi*x), cos(f_0*pi*x), ..., cos(f_{L-1}*pi*x)]``.
    An input with ``C`` channels therefore yields ``2 * L * C`` output channels.
    Within each sin or cos group the order is frequency-major, then input channel.

    Args:
        L: Number of frequencies.
    """

    def __init__(self, L):
        super().__init__()
        self.L = L

    def forward(self, x):
        if not x.is_floating_point():
            x = x.to(torch.get_default_dtype())

        # Use the configured self.L, not any same-named global, and build the
        # frequencies on the input's device and dtype.
        frequencies = 2.0 ** torch.arange(self.L, device=x.device, dtype=x.dtype)

        # Reshape (L,) to (1, L, 1, ..., 1) so each frequency scales the whole input,
        # which is given a new frequency axis at dim 1: (N, L, C, ...).
        frequency_shape = (1, self.L) + (1,) * (x.dim() - 1)
        phases = (frequencies * torch.pi).view(frequency_shape) * x.unsqueeze(1)
        # (N, L, C, ...) -> (N, L*C, ...): same ordering as concatenating per-frequency blocks.
        phases = phases.flatten(1, 2)

        return torch.cat([torch.sin(phases), torch.cos(phases)], dim=1)

## InverseConv2D Layer

In [None]:
class InverseConv2DLayer(nn.Module):
    """Coordinate-based network that predicts a reflectance map on the input's grid.

    The input's values are not used. The input is upsampled by ``subsample_factor``
    only to fix the working grid size. A normalised (x, y) coordinate grid of that size
    is then built and optionally Fourier-encoded. It passes through
    ``num_standard_layers`` 3x3 conv + ELU layers and a 3x3 output conv that produces
    one channel. The result is finally max-pooled back down by ``subsample_factor``.

    Args:
        L: Number of Fourier frequencies (used when ``use_fourier`` is True).
        num_standard_layers: Number of hidden 3x3 convolution layers.
        max_reflectance: Upper bound of the predicted reflectance when
            ``use_sigmoidal_output`` is True.
        subsample_factor: Upsampling factor before the network and max-pool
            window/stride after it.
        use_fourier: If True, feed Fourier features of the coordinates (4*L channels);
            otherwise feed the 2 raw coordinate channels.
        use_sigmoidal_output: If True (default), output ``max_reflectance * sigmoid(z)``,
            which spans (0, max_reflectance). Otherwise output ``softplus(z)``.

    Returns (from ``forward``):
        ``(output_coordinate, output_fourier, output, output_inverse)``: the
        coordinate grid, the network input features, the reflectance at the upsampled
        resolution, and the reflectance max-pooled back to the input resolution.
    """

    def __init__(self, L, num_standard_layers, max_reflectance, subsample_factor=2, use_fourier=True,
                 use_sigmoidal_output=True):
        super().__init__()
        self.L = L
        self.num_standard_layers = num_standard_layers
        self.max_reflectance = max_reflectance
        self.use_fourier = use_fourier
        self.use_sigmoidal_output = use_sigmoidal_output
        # The coordinate layer emits 2 channels. Fourier encoding maps each of them to
        # 2*L (sin + cos) channels, i.e. 4*L in total.
        self.num_fourier_channels = 4 * L if use_fourier else 2

        self.upsample_layer = nn.Upsample(scale_factor=subsample_factor, mode='nearest')
        self.coordinate_layer = CoordConv2DLayer()
        # Parameter-free; built unconditionally so the module layout does not depend on use_fourier.
        self.fourier_layer = FourierConv2DLayer(L)

        self.standard_hidden_layers = nn.ModuleList(
            [nn.Conv2d(self.num_fourier_channels, self.num_fourier_channels, kernel_size=3, padding='same')
             for _ in range(num_standard_layers)]
        )
        self.standard_output_layer = nn.Conv2d(self.num_fourier_channels, 1, kernel_size=3, padding='same')
        self.downsample_layer = nn.MaxPool2d(kernel_size=subsample_factor, stride=subsample_factor)

        # He (Kaiming) uniform initialisation for every conv weight; biases start at zero.
        for conv in (*self.standard_hidden_layers, self.standard_output_layer):
            nn.init.kaiming_uniform_(conv.weight, mode='fan_in', nonlinearity='leaky_relu')
            nn.init.zeros_(conv.bias)

    def forward(self, x):
        x = self.upsample_layer(x)
        output_coordinate = self.coordinate_layer(x)
        output_fourier = self.fourier_layer(output_coordinate) if self.use_fourier else output_coordinate

        hidden = output_fourier
        for layer in self.standard_hidden_layers:
            hidden = nn.functional.elu(layer(hidden))

        logits = self.standard_output_layer(hidden)
        if self.use_sigmoidal_output:
            # Sigmoid is applied directly to the logits. A preceding softplus (always > 0)
            # would force sigmoid(...) > 0.5 and make any reflectance below
            # max_reflectance / 2 (e.g. a dark background) unreachable.
            output = self.max_reflectance * torch.sigmoid(logits)
        else:
            output = nn.functional.softplus(logits)

        output_inverse = self.downsample_layer(output)

        return output_coordinate, output_fourier, output, output_inverse

## Microscope Layer

In [None]:
class MicroscopeCNNLayer(nn.Module):
    """Fixed forward model of the microscope: PSF blur, intensity, peak normalisation.

    Args:
        psf_kernel: 2-D kernel used as the frozen weight of a single-channel, bias-free
            'same'-padded convolution. It is assumed square, because only
            ``psf_kernel.size(0)`` sets the conv kernel size.

    Returns (from ``forward``):
        ``(output_conv, output_intensity, output_final)``: the blurred field, its
        intensity ``|field|**2``, and the intensity divided by its maximum over the
        whole tensor (so the brightest pixel is 1). An all-zero intensity stays all-zero.
    """

    def __init__(self, psf_kernel):
        super().__init__()
        self.conv_layer = nn.Conv2d(1, 1, kernel_size=psf_kernel.size(0), padding='same', bias=False)
        # Replace the randomly initialised weight with the PSF and freeze it.
        self.conv_layer.weight = nn.Parameter(psf_kernel.unsqueeze(0).unsqueeze(0), requires_grad=False)
        self.intensity_layer = IntensityLayer()

    def forward(self, x):
        output_conv = self.conv_layer(x)
        output_intensity = self.intensity_layer(output_conv)
        # Normalise by the global peak, dividing by 1 instead when the peak is 0 so
        # that an all-zero image does not become NaN (0 / 0).
        peak = torch.max(output_intensity)
        peak = torch.where(peak > 0, peak, torch.ones_like(peak))
        output_final = output_intensity / peak
        return output_conv, output_intensity, output_final

In [None]:
class IntensityLayer(nn.Module):
    """Element-wise intensity of a (possibly complex) field: ``|x| ** 2``."""

    def forward(self, x):
        magnitude = torch.abs(x)
        return magnitude.pow(2)

## PINN

In [None]:
class PINN(nn.Module):
    """Physics-informed network: a learned inverse model followed by a fixed microscope model.

    ``forward(x)`` runs the inverse layer on ``x`` and then simulates the microscope
    image from the inverse layer's downsampled reflectance. It returns
    ``(reflectance, simulated_image)``, where ``reflectance`` is the inverse layer's
    full-resolution (upsampled) output.
    """

    def __init__(self, L, num_standard_layers, max_reflectance, subsample_factor, psf_kernel, use_fourier=True):
        super().__init__()
        # Submodules are created in this order (inverse first, forward second), as before.
        self.inverse_layer = InverseConv2DLayer(L, num_standard_layers, max_reflectance, subsample_factor, use_fourier=use_fourier)
        self.forward_layer = MicroscopeCNNLayer(psf_kernel)

    def forward(self, x):
        _, _, reflectance, reflectance_downsampled = self.inverse_layer(x)
        _, _, simulated_image = self.forward_layer(reflectance_downsampled)
        return reflectance, simulated_image

## Sensor function

In [None]:
def sensor_func(image, noise_level=0.1, subsample_factor=2):
    """Simulate a camera sensor: additive Gaussian noise, max-pool subsampling, clipping.

    Args:
        image: Tensor accepted by ``max_pool2d``, i.e. (N, C, H, W) or (C, H, W).
        noise_level: Standard deviation multiplier for the zero-mean Gaussian noise.
        subsample_factor: Max-pool window size and stride.

    Returns:
        The noisy, subsampled image clipped to [0, 1].
    """
    noisy = image + noise_level * torch.randn_like(image)
    pooled = nn.functional.max_pool2d(noisy, kernel_size=subsample_factor, stride=subsample_factor)
    return pooled.clamp(0, 1)

## Generate "Training" Image

In [None]:
# Simulate the measured image from the ground truth. With model mismatch enabled, the
# "real" microscope blurs with a wider PSF (sigma=2.0, 21x21) than psf_kernel, the
# kernel handed to the PINN below. This is presumably meant to test robustness to an
# imperfect forward model.
apply_model_mismatch = True
if not apply_model_mismatch:
    microscope_model = MicroscopeCNNLayer(psf_kernel)
    _, _, training_image = microscope_model(reflectance_ground_truth)
else:
    psf_kernel_real = generate_psf_kernel(sigma=2.0, psf_size=21)
    microscope_model_real = MicroscopeCNNLayer(psf_kernel_real)
    _, _, training_image = microscope_model_real(reflectance_ground_truth)

# Optional sensor stage: add noise, subsample by subsample_factor, clip to [0, 1].
subsample_factor = 2
apply_sensor = True
if apply_sensor:
    training_image = sensor_func(training_image, subsample_factor=subsample_factor)

## Visualize Inputs

### Image Plot

In [None]:
ground_truth_plot = reflectance_ground_truth.squeeze().numpy()
training_image_plot = training_image.squeeze().numpy()

fig, (ax_1, ax_2) = plt.subplots(1, 2, figsize=(10, 6))
fig.suptitle("Visualize Inputs")

# Same axis labels on both panels; only the image and title differ.
for panel_ax, panel_image, panel_title in (
    (ax_1, ground_truth_plot, 'Ground Truth'),
    (ax_2, training_image_plot, 'Training Image'),
):
    panel_ax.imshow(panel_image, cmap='gray')
    panel_ax.set_xlabel('X-coordinate of pixel')
    panel_ax.set_ylabel('Y-coordinate of pixel')
    panel_ax.set_title(panel_title)

plt.show()

In [None]:
# Ground-truth reflectance on its own.
plt.imshow(X=ground_truth_plot, cmap='gray')

In [None]:
# Plotting grid in image convention (np.meshgrid default 'xy' indexing). xt[i, j] = j
# is the column (X) index and yt[i, j] = i the row (Y) index; both are shaped like
# training_image_plot. torch.meshgrid's default 'ij' indexing put rows on the X axis.
xt, yt = np.meshgrid(np.arange(training_image_plot.shape[1]), np.arange(training_image_plot.shape[0]))

fig = plt.figure(figsize=(10, 6))
ax = fig.add_subplot(111, projection='3d')
ax.plot_surface(xt, yt, training_image_plot, cmap='viridis', alpha=0.4)

ax.set_xlabel('X-coordinate')
ax.set_ylabel('Y-coordinate')
ax.set_title('Training Image Map')
plt.show()

In [None]:
# `training_image_positive` was not defined anywhere, so this cell raised NameError.
# It is presumably meant to be the non-negative part of the training image; confirm the intent.
# When apply_sensor is True, sensor_func has already clamped to [0, 1], so this clamp is a no-op.
training_image_positive = torch.clamp(training_image, min=0)
training_image_positive_plot = training_image_positive.squeeze().detach().numpy()
plt.imshow(training_image_positive_plot, cmap='gray')

### Surface Plot

In [None]:
fig = plt.figure(figsize=(10, 6))
fig.suptitle("Surface plot of the Training Image")

# Opaque surface of the training image over the (xt, yt) grid.
ax_1 = fig.add_subplot(1, 1, 1, projection='3d')
surface = ax_1.plot_surface(xt, yt, training_image_plot, cmap='viridis')
ax_1.set_xlabel('X-coordinate')
ax_1.set_ylabel('Y-coordinate')

plt.show()

In [None]:
fig = plt.figure(figsize=(10, 6))
fig.suptitle("Surface plot of the Ground Truth")

# The ground truth has the sample image's height and width, so the (xp, yp) grid is reused.
ax_1 = fig.add_subplot(1, 1, 1, projection='3d')
surface = ax_1.plot_surface(xp, yp, ground_truth_plot, cmap='viridis')
ax_1.set_xlabel('X-coordinate')
ax_1.set_ylabel('Y-coordinate')

plt.show()

## Check size of Training Image

In [None]:
# Subsampled measurement size: (batch, channels, height, width).
training_image.size()

## Define configurations for Inverse Layer

In [None]:
# Inverse-network hyper-parameters.
use_fourier = True                # Fourier-encode the coordinate grid
L = 4                             # number of frequencies: 2**0 .. 2**(L-1)
num_standard_layers = 4           # hidden 3x3 conv layers
max_reflectance = alpha           # scale of the sigmoidal reflectance output

## Network training

### Model summary

In [None]:
pinn_model_dummy = PINN(L, num_standard_layers, max_reflectance, subsample_factor, psf_kernel, use_fourier)
_, channels_dummy, height_dummy, width_dummy = training_image.size()
# torchsummary defaults to device="cuda", which feeds CUDA inputs to this CPU-resident
# model whenever a GPU is available. Pin the device to where the model lives.
summary(pinn_model_dummy, input_size=(channels_dummy, height_dummy, width_dummy), device="cpu")

In [None]:
# Initialize PINN model
pinn_model = PINN(L, num_standard_layers, max_reflectance, subsample_factor, psf_kernel, use_fourier)

# Define loss function and optimizer
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(pinn_model.parameters(), lr=1e-4)

# Set up the exponential learning rate scheduler
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=np.exp(np.log(5e-6 / 1e-4) / 10000))

# Set the regularization strengths (lambdas)
lambda_boundary = 0.25

# Number of training epochs
num_epochs = 10000

# Define list to store loss values
loss_list = []

# Define list to store intermediate outputs for reflectance
intermediate_output_reflectance = []

# Training loop
for epoch in range(num_epochs):
    # Forward pass
    _, iter_image = pinn_model(training_image)

    # Calculate the mse loss
    loss_mse = criterion(iter_image, training_image)
    
    # Calculate the boundary regularization loss
    loss_boundary = lambda_boundary * (
          torch.square(iter_image[:, :, 0, :] - 0).mean()
        + torch.square(iter_image[:, :, -1, :] - 0).mean()
        + torch.square(iter_image[:, :, :, 0] - 0).mean()
        + torch.square(iter_image[:, :, :, -1] - 0).mean()
    )
    
    # Calculate the total loss
    use_boundary_loss = True
    if use_boundary_loss:
        loss_total = loss_mse + loss_boundary
    else:
        loss_total = loss_mse
    loss = loss_total

    # Backward pass and optimization
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    # Update the learning rate
    scheduler.step()

    # Print training statistics
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item()}')
    loss_list.append(loss.item())

## Display Loss history

In [None]:
# One point per training epoch, plotted against its epoch index.
plt.figure()
plt.plot(range(len(loss_list)), loss_list)
plt.xlabel('Iteration')
plt.ylabel('Training Loss')
plt.title('Training Loss History')

## Generate predictions

In [None]:
# Inference only: disable autograd so no computation graph is built or kept alive.
with torch.no_grad():
    predicted_reflectance, predicted_image = pinn_model(training_image)

In [None]:
# Reflectance at the inverse layer's upsampled resolution.
predicted_reflectance.size()

In [None]:
# Simulated microscope image at the training-image resolution.
predicted_image.size()

## Visualize results

In [None]:
predicted_reflectance_plot = predicted_reflectance.squeeze().detach().numpy()

fig, (ax_1, ax_2, ax_3) = plt.subplots(1, 3, figsize=(16, 6))
fig.suptitle("Visualize Reflectance")

# Left to right: ground truth, measured (training) image, PINN reconstruction.
for panel_ax, panel_image in zip(
    (ax_1, ax_2, ax_3),
    (ground_truth_plot, training_image_plot, predicted_reflectance_plot),
):
    panel_ax.imshow(panel_image, cmap='gray')

ax_1.set_ylabel('Y-coordinate')
ax_2.set_xlabel('X-coordinate')
ax_1.set_title('Ground Truth')
ax_2.set_title('Training Image')
ax_3.set_title('PINN Prediction')

In [None]:
fig, (ax_1, ax_2) = plt.subplots(1, 2, figsize=(16, 6))
fig.suptitle("Visualize Reflectance")

# Side-by-side comparison of the ground truth and the reconstruction.
ax_1.imshow(ground_truth_plot, cmap='gray')
ax_2.imshow(predicted_reflectance_plot, cmap='gray')

ax_1.set_ylabel('Y-coordinate')
ax_2.set_xlabel('X-coordinate')

ax_1.set_title('Ground Truth')
ax_2.set_title('PINN Prediction')

In [None]:
# Plotting grid in image convention (np.meshgrid default 'xy' indexing). xr[i, j] = j
# is the column (X) index and yr[i, j] = i the row (Y) index, matching the axis labels.
xr, yr = np.meshgrid(np.arange(predicted_reflectance_plot.shape[1]), np.arange(predicted_reflectance_plot.shape[0]))

fig = plt.figure(figsize=(10, 6))
fig.suptitle("Surface plot of the Prediction")

ax_1 = fig.add_subplot(111, projection='3d')
surface = ax_1.plot_surface(xr, yr, predicted_reflectance_plot, cmap='viridis')
ax_1.set_xlabel('X-coordinate')
ax_1.set_ylabel('Y-coordinate')

plt.show()

In [None]:
# plt.hist treats each column of a 2-D array as a separate dataset and draws one
# histogram per column. Flatten to get a single histogram of all predicted values.
plt.hist(predicted_reflectance_plot.ravel())

In [None]:
# Shape of the predicted reflectance, for comparison with the ground truth below.
predicted_reflectance.size()

In [None]:
# Shape of the ground-truth reflectance (batch dimension included).
reflectance_ground_truth.size()

In [None]:
# PINN reflectance prediction on its own.
plt.imshow(X=predicted_reflectance_plot, cmap='gray')

In [None]:
# Measured (training) image on its own.
plt.imshow(X=training_image_plot, cmap='gray')