# Import all packages

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import tqdm

from torch.utils.data import DataLoader, random_split, Dataset
import torchvision.transforms as transforms
from torchvision import datasets

import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
import torchvision.transforms as transforms
from torchvision import datasets

from torchinfo import summary

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
 torch.cuda.is_available()

# Angular Spectrum

In [None]:
def angular_spectrum_propagation(U0, wvl, dx, z):
    """
    Propagate an optical field over a distance z with the band-limited angular spectrum method.

    The field is zero-padded by half its size on every side. It is then multiplied in
    the frequency domain by the angular-spectrum transfer function and cropped back to
    the original size. Frequencies outside the band limit of Matsushima & Shimobaba
    (2009) are removed, as are evanescent components.

    Parameters:
        U0      : 4D torch tensor (batch, channel, height, width), real or complex.
                  Height and width may differ.
        wvl     : float (wavelength of light in meters)
        dx      : float (sampling interval in meters, same along both axes)
        z       : float or scalar tensor (propagation distance in meters)

    Returns:
        U1  : the propagated complex field, same shape as U0
    """
    batch, channel, original_height, original_width = U0.shape
    pad_height, pad_width = original_height // 2, original_width // 2

    # Zero-pad so the circular convolution implied by the FFT does not wrap
    # light from one edge of the window into the opposite edge.
    U0_padded = torch.zeros(
        (batch, channel, original_height + 2 * pad_height, original_width + 2 * pad_width),
        dtype=U0.dtype,
        device=U0.device,
    )
    U0_padded[:, :, pad_height:pad_height + original_height, pad_width:pad_width + original_width] = U0

    ny, nx = U0_padded.shape[-2:]
    k = 2 * torch.pi / wvl

    # Spatial-frequency grids in unshifted FFT order. Meshgridding (fy, fx) with
    # indexing="ij" yields (ny, nx) arrays in which FX varies along the last
    # (width) axis, matching U0_padded for square and non-square fields alike.
    fx = torch.fft.fftfreq(nx, dx, device=U0.device)
    fy = torch.fft.fftfreq(ny, dx, device=U0.device)
    FY, FX = torch.meshgrid(fy, fx, indexing="ij")
    FX = FX.reshape(1, 1, ny, nx)
    FY = FY.reshape(1, 1, ny, nx)

    # Accept either a Python number or a tensor for the distance.
    z = torch.as_tensor(z, dtype=fx.dtype, device=U0.device)

    # Band limit: frequencies whose transfer-function phase would be undersampled
    # on this grid are masked out.
    Delta_uy = 1 / (ny * dx)
    Delta_ux = 1 / (nx * dx)
    u_limity = 1 / (torch.sqrt((2 * Delta_uy * z) ** 2 + 1) * wvl)
    u_limitx = 1 / (torch.sqrt((2 * Delta_ux * z) ** 2 + 1) * wvl)
    H_limit = ((FY ** 2 / u_limity ** 2 + FX ** 2 * wvl ** 2) < 1) & (
        (FX ** 2 / u_limitx ** 2 + FY ** 2 * wvl ** 2) < 1
    )

    # Transfer function. For evanescent frequencies the square-root argument is
    # negative. torch.sqrt would give NaN there, and NaN * 0 is still NaN, so the
    # argument is clamped to 0 first. Those frequencies lie outside H_limit and
    # are zeroed by the mask.
    kz_term = torch.sqrt(torch.clamp(1 - (wvl * FX) ** 2 - (wvl * FY) ** 2, min=0))
    H = torch.exp(1j * k * z * kz_term) * H_limit

    # Propagate: ifftshift puts the spatial origin at index 0 so the spectrum is
    # in the same (unshifted) order as FX/FY; fftshift undoes it afterwards.
    spectrum = torch.fft.fft2(torch.fft.ifftshift(U0_padded, dim=(-2, -1)))
    U1_padded = torch.fft.fftshift(torch.fft.ifft2(spectrum * H), dim=(-2, -1))

    # Crop back to the original field size.
    U1 = U1_padded[:, :, pad_height:pad_height + original_height, pad_width:pad_width + original_width]

    return U1

# Define the U-Net architecture and load the pretrained model
This model was trained on images of the digits 3 and 7, all defocused by the same propagation distance.

In [None]:
class UNet(nn.Module):
    """
    Small convolutional encoder-decoder mapping a 1-channel image to a 1-channel image.

    Despite the name there are no skip connections: forward() simply chains
    encoder -> middle -> decoder. The encoder halves the spatial size with
    MaxPool2d(2) and the decoder doubles it with bilinear upsampling. The output
    therefore has the input's size only when height and width are even. The final
    ReLU makes every output value non-negative.

    NOTE(review): the pretrained checkpoint is loaded with torch.load(weights_only=False),
    which appears to unpickle a whole UNet object. Keep the class name and the
    encoder/middle/decoder layer layout unchanged so that checkpoint stays loadable.
    """

    @staticmethod
    def _conv_bn_relu(in_channels, out_channels):
        """Return [3x3 same-padding Conv2d, BatchNorm2d, in-place ReLU] as a flat list."""
        return [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]

    def __init__(self):
        super().__init__()
        stage = self._conv_bn_relu

        # The helper lists are unpacked into flat Sequentials, so layer indices
        # (and state_dict keys such as "encoder.3.weight") follow one flat numbering.
        self.encoder = nn.Sequential(*stage(1, 64), *stage(64, 128), nn.MaxPool2d(2))
        self.middle = nn.Sequential(*stage(128, 256), *stage(256, 128))
        self.decoder = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            *stage(128, 64),
            *stage(64, 1),
        )

    def forward(self, x):
        """Apply encoder, middle and decoder in sequence."""
        return self.decoder(self.middle(self.encoder(x)))

# NOTE(review): weights_only=False lets torch.load unpickle arbitrary Python
# objects, which can execute code on load. The checkpoint appears to hold a whole
# pickled model, since the result is used directly as a module. Only load this
# file from a trusted source.
model = torch.load(
    'U-net_defocus.pt',
    map_location=torch.device(device),
    weights_only=False,
)
# eval() switches BatchNorm layers to inference behaviour, and to() moves the
# parameters to `device`. Both return the module itself, so chaining them leaves
# `model` ready for inference and echoes it as the cell output.
model.eval().to(device)

# Evaluate Model
We can evaluate the model, but make sure the input is normalized to [0, 1], i.e. its peak intensity should be 1.

In [None]:
def evaluate(model, img, size=(28, 28)):
    """
    Run `model` on a single image and return its output as a 2-D tensor.

    Parameters:
        model : torch module taking a (1, 1, H, W) tensor
        img   : torch tensor holding exactly H * W values, e.g. shape (H, W),
                (1, 1, H, W) or flattened (H * W,)
        size  : (H, W) of the image; defaults to (28, 28) for MNIST

    Returns:
        torch tensor of shape (H, W) with the model output. It is computed under
        torch.no_grad(), so no autograd graph is built for this inference call.
    """
    height, width = size
    # Inference only: skipping gradient tracking saves memory and time.
    with torch.no_grad():
        output = model(img.reshape(1, 1, height, width))
    return output.reshape(height, width)

def plot_images(inputs, outputs, targets):
    """
    Show the input, the network output and the target side by side.

    Each argument is a torch tensor with 28 * 28 elements. It is reshaped to
    (28, 28), detached from autograd and moved to the CPU before plotting.
    """
    panels = (
        (inputs, "Input Image"),
        (outputs, "Network Output Image"),
        (targets, "Target Image"),
    )

    # One row, three columns: one axis per panel.
    _, axes = plt.subplots(1, 3, figsize=(15, 5))
    for axis, (tensor, title) in zip(axes, panels):
        pixels = tensor.reshape(28, 28).detach().cpu().numpy()
        axis.imshow(pixels)
        axis.set_title(title)
        axis.axis('off')

    plt.show()

# Load dataset and prepare
The dataset's images are stored in a tensor whose first dimension indexes the individual images.

In [None]:
# Load MNIST dataset
transform = transforms.Compose([
    transforms.ToTensor(),
])
dataset = datasets.MNIST(root="./data", train=True, transform=transform, download=True)

only_number = 5
dataset.data = dataset.data[(dataset.targets == only_number)]
dataset.targets = dataset.targets[(dataset.targets == only_number)]

In [None]:
# Display the shape of the filtered image tensor; its first dimension is the
# number of images kept for the selected digit.
dataset.data.shape

In [None]:
# Pick one filtered digit image and defocus it with the angular spectrum method.
# Then compare the network's output against the original (in-focus) image.
index = 1000

# Optical parameters of the simulated defocus (all in meters).
wavelength = 633e-9        # wavelength of light
pixel_pitch = 3e-6         # sampling interval dx
defocus_distance = 300e-6  # propagation distance z

# dataset.data is already a tensor, so convert it with .to() instead of
# torch.tensor(), which warns when copy-constructing from an existing tensor.
image = dataset.data[index].to(device=device, dtype=torch.float) / 255

# Intensity |U|^2 of the defocused field.
image_prop = torch.abs(
    angular_spectrum_propagation(image.reshape(1, 1, 28, 28), wavelength, pixel_pitch, defocus_distance)
) ** 2
# NOTE(review): image_prop is not rescaled to a peak of 1 here; confirm this
# matches how the network's training inputs were prepared.
output = evaluate(model, image_prop)

plot_images(image_prop, output, image)

# Task

- Try evaluating the network on different digits (other than 3s and 7s).
- Can you compare the performance on 3s vs. 5s using the SSIM metric?
- What happens if you do not use simple MNIST images but instead something more realistic (Fashion MNIST or any image of your choice)?