# Trichromatic Neural Étendue Expansion Simulation Code

### This notebook can be used to produce the trichromatic étendue expanded simulation holograms shown in the manuscript and in the supplementary information.

### In the configuration cell below, select exactly one expander type and one target image by uncommenting its line. For example, to produce a 36x étendue-expanded hologram with the neural étendue expander, select 'neural_tri_36x'; to produce a 4x étendue-expanded hologram with a random expander [Kuo et al. 2020], select 'random_4x'. The provided target images are 'Target_Images/000.png', 'Target_Images/001.png', and so on; set target_img_name to the file name without the '.png' extension (e.g. '000').

In [None]:
import torch
import torch.nn.functional as F
from torch.autograd import Variable
import matplotlib.pyplot as plt
import numpy as np
import os
import time
import random

from Expander import get_expander_phase
from SLM import get_slm_phase
from Render import get_intensity
from RI_wvl import RI_660, RI_517, RI_450, wvl_660, wvl_517, wvl_450
from Freq import butterworth, freq_filt
from Img import get_img

In [None]:
### --- BEGIN CONFIG --- ###
# Choose only one of the following expanders.
#expander_type = 'random_4x'
#expander_type = 'random_16x'
#expander_type = 'random_36x'
#expander_type = 'random_64x'
#expander_type = 'neural_tri_4x'
#expander_type = 'neural_tri_16x'
#expander_type = 'neural_tri_36x'
expander_type = 'neural_tri_64x'

# Choose only one of the following target images.
#target_img_name = '000'
#target_img_name = '001'
target_img_name = '002'
### ---   END CONFIG --- ###

# Per-axis upsampling factor derived from the expander name. Each value is the
# square root of the étendue expansion factor (4x -> 2, 16x -> 4, 36x -> 6,
# 64x -> 8), because the expanded grid is upsample_factor times larger along
# both rows and columns.
if '_4x' in expander_type:
    upsample_factor = 2
elif '_16x' in expander_type:
    upsample_factor = 4
elif '_36x' in expander_type:
    upsample_factor = 6
elif '_64x' in expander_type:
    upsample_factor = 8
else:
    # Raise explicitly: an assert on a non-empty string never fails, which
    # would leave upsample_factor undefined and crash much later instead.
    raise ValueError(f"Undefined expander: {expander_type!r}")

In [None]:
### --- BEGIN CONSTANTS --- ###
# SLM resolution in pixels (rows, columns); the SLM amplitude is built as R x C.
R, C = 512, 512
# Forwarded unchanged to get_slm_phase and get_intensity.
batch_size = 3
# Resolution of the étendue-expanded (upsampled) grid.
up_R, up_C = R * upsample_factor, C * upsample_factor
# Optimizer settings forwarded to get_slm_phase.
iterations, lr = 1000, 0.1
### ---   END CONSTANTS --- ###

In [None]:
### Set Random Seed for reproducibility
def set_seed(seed: int = 42) -> None:
    """Seed the Python, NumPy and PyTorch RNGs and force deterministic cuDNN.

    Args:
        seed: Value applied to every random generator (default 42).

    Side effects:
        Mutates global RNG state, the ``torch.backends.cudnn`` flags and
        ``os.environ["PYTHONHASHSEED"]``, then prints the seed that was set.
    """
    for seeder in (np.random.seed, random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)

    # Disable cuDNN autotuning and request deterministic kernels.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    # NOTE: the interpreter reads PYTHONHASHSEED only at startup, so this does
    # not change hash() in the running kernel; it only affects Python
    # subprocesses launched afterwards.
    os.environ["PYTHONHASHSEED"] = str(seed)
    print(f"Random seed set as {seed}")


set_seed(1996)

In [None]:
### Define frequency filter ###
# Butterworth filter on the upsampled (up_R x up_C) grid. The third argument is
# the upsampling factor; its exact role is defined in Freq.butterworth.
H = butterworth(up_R, up_C, upsample_factor)
# float32 copy on the GPU, consumed by freq_filt / get_slm_phase / get_intensity.
H_torch = torch.from_numpy(H.astype(np.float32)).to('cuda:0')

### Optimize for Red Wavelength (660 nm)

In [None]:
### Define target image (monochromatic at 660 nm) ###
img_name = os.path.join('Target_Images', f'{target_img_name}.png')
# Load the red channel on the expanded grid. wvlpad_R / wvlpad_C are the
# row/column padding amounts that the stash cell crops away later.
img, wvlpad_R, wvlpad_C = get_img(img_name, up_R, up_C, 'r')
target = torch.from_numpy(img.astype(np.float32)).to('cuda:0')
# Band-limit the target with the frequency filter H_torch.
freq_target = freq_filt(target, H_torch)

In [None]:
### Define expander (monochromatic at 660 nm) ###
# wvl_660 / RI_660 presumably give the wavelength and refractive index at 660 nm
# (from RI_wvl); confirm the units there.
expander_phase = get_expander_phase(expander_type, wvl_660, RI_660).to('cuda:0')
# Phase-only expander: unit amplitude over the whole expanded grid.
expander_amp = torch.ones(up_R, up_C, device='cuda:0')

In [None]:
### Optimize for SLM ###
# Unit-amplitude (phase-only) SLM at native R x C resolution.
slm_amp = torch.ones(R, C, device='cuda:0')
slm_phase = get_slm_phase(
    freq_target, expander_amp, expander_phase, slm_amp, H_torch,
    batch_size, iterations, lr, R, C, upsample_factor, wvlpad_R, wvlpad_C,
)

In [None]:
### Render etendue expanded hologram with optimized SLM pattern ###
# Nearest-neighbour upsampling (F.interpolate's default mode) of the SLM phase
# onto the expander grid; a 2-element scale_factor requires a 4-D slm_phase.
scale = (upsample_factor, upsample_factor)
slm_phase_upsample = F.interpolate(slm_phase, scale_factor=scale)
slm_amp_upsample = torch.ones_like(slm_phase_upsample)
loss, output = get_intensity(
    freq_target, expander_amp, expander_phase, slm_amp_upsample, slm_phase_upsample, H_torch,
    batch_size, R, C, upsample_factor, wvlpad_R, wvlpad_C,
)
print('Loss value is {}'.format(loss))

In [None]:
# Keep the red-channel results before the green section overwrites the shared
# names. The target is cropped by the padding returned from get_img and resized
# to the expanded (up_R x up_C) grid for display. Stop indices come from the
# tensor shape so that a padding of 0 keeps the whole image (a `p:-p` slice is
# empty when p == 0).
freq_target_r = F.interpolate(
    freq_target[None, None,
                wvlpad_R:freq_target.shape[-2] - wvlpad_R,
                wvlpad_C:freq_target.shape[-1] - wvlpad_C],
    size=(R * upsample_factor, C * upsample_factor), mode='nearest')
expander_phase_r = expander_phase
expander_amp_r = expander_amp
slm_phase_r = slm_phase
slm_amp_r = slm_amp
output_r = output

### Optimize for Green Wavelength (517 nm)

In [None]:
### Define target image (monochromatic at 517 nm) ###
img_name = os.path.join('Target_Images', f'{target_img_name}.png')
# Load the green channel on the expanded grid. wvlpad_R / wvlpad_C are the
# row/column padding amounts that the stash cell crops away later.
img, wvlpad_R, wvlpad_C = get_img(img_name, up_R, up_C, 'g')
target = torch.from_numpy(img.astype(np.float32)).to('cuda:0')
# Band-limit the target with the frequency filter H_torch.
freq_target = freq_filt(target, H_torch)

In [None]:
### Define expander (monochromatic at 517 nm) ###
# wvl_517 / RI_517 presumably give the wavelength and refractive index at 517 nm
# (from RI_wvl); confirm the units there.
expander_phase = get_expander_phase(expander_type, wvl_517, RI_517).to('cuda:0')
# Phase-only expander: unit amplitude over the whole expanded grid.
expander_amp = torch.ones(up_R, up_C, device='cuda:0')

In [None]:
### Optimize for SLM ###
# Unit-amplitude (phase-only) SLM at native R x C resolution.
slm_amp = torch.ones(R, C, device='cuda:0')
slm_phase = get_slm_phase(
    freq_target, expander_amp, expander_phase, slm_amp, H_torch,
    batch_size, iterations, lr, R, C, upsample_factor, wvlpad_R, wvlpad_C,
)

In [None]:
### Render etendue expanded hologram with optimized SLM pattern ###
# Nearest-neighbour upsampling (F.interpolate's default mode) of the SLM phase
# onto the expander grid; a 2-element scale_factor requires a 4-D slm_phase.
scale = (upsample_factor, upsample_factor)
slm_phase_upsample = F.interpolate(slm_phase, scale_factor=scale)
slm_amp_upsample = torch.ones_like(slm_phase_upsample)
loss, output = get_intensity(
    freq_target, expander_amp, expander_phase, slm_amp_upsample, slm_phase_upsample, H_torch,
    batch_size, R, C, upsample_factor, wvlpad_R, wvlpad_C,
)
print('Loss value is {}'.format(loss))

In [None]:
# Keep the green-channel results before the blue section overwrites the shared
# names. The target is cropped by the padding returned from get_img and resized
# to the expanded (up_R x up_C) grid for display. Stop indices come from the
# tensor shape so that a padding of 0 keeps the whole image (a `p:-p` slice is
# empty when p == 0).
freq_target_g = F.interpolate(
    freq_target[None, None,
                wvlpad_R:freq_target.shape[-2] - wvlpad_R,
                wvlpad_C:freq_target.shape[-1] - wvlpad_C],
    size=(R * upsample_factor, C * upsample_factor), mode='nearest')
expander_phase_g = expander_phase
expander_amp_g = expander_amp
slm_phase_g = slm_phase
slm_amp_g = slm_amp
output_g = output

### Optimize for Blue Wavelength (450 nm)

In [None]:
### Define target image (monochromatic at 450 nm) ###
img_name = os.path.join('Target_Images', f'{target_img_name}.png')
# Load the blue channel on the expanded grid. wvlpad_R / wvlpad_C are the
# row/column padding amounts returned alongside the image.
img, wvlpad_R, wvlpad_C = get_img(img_name, up_R, up_C, 'b')
target = torch.from_numpy(img.astype(np.float32)).to('cuda:0')
# Band-limit the target with the frequency filter H_torch.
freq_target = freq_filt(target, H_torch)

In [None]:
### Define expander (monochromatic at 450 nm) ###
# wvl_450 / RI_450 presumably give the wavelength and refractive index at 450 nm
# (from RI_wvl); confirm the units there.
expander_phase = get_expander_phase(expander_type, wvl_450, RI_450).to('cuda:0')
# Phase-only expander: unit amplitude over the whole expanded grid.
expander_amp = torch.ones(up_R, up_C, device='cuda:0')

In [None]:
### Optimize for SLM ###
# Unit-amplitude (phase-only) SLM at native R x C resolution.
slm_amp = torch.ones(R, C, device='cuda:0')
slm_phase = get_slm_phase(
    freq_target, expander_amp, expander_phase, slm_amp, H_torch,
    batch_size, iterations, lr, R, C, upsample_factor, wvlpad_R, wvlpad_C,
)

In [None]:
### Render etendue expanded hologram with optimized SLM pattern ###
# Nearest-neighbour upsampling (F.interpolate's default mode) of the SLM phase
# onto the expander grid; a 2-element scale_factor requires a 4-D slm_phase.
scale = (upsample_factor, upsample_factor)
slm_phase_upsample = F.interpolate(slm_phase, scale_factor=scale)
slm_amp_upsample = torch.ones_like(slm_phase_upsample)
loss, output = get_intensity(
    freq_target, expander_amp, expander_phase, slm_amp_upsample, slm_phase_upsample, H_torch,
    batch_size, R, C, upsample_factor, wvlpad_R, wvlpad_C,
)
print('Loss value is {}'.format(loss))

In [None]:
# Keep the blue-channel results for display. Crop the padding returned by
# get_img exactly as the red and green channels do, so all three target
# channels are resized from the same region. Stop indices come from the tensor
# shape, so a padding of 0 keeps the whole image (a `p:-p` slice would be empty).
freq_target_b = F.interpolate(
    freq_target[None, None,
                wvlpad_R:freq_target.shape[-2] - wvlpad_R,
                wvlpad_C:freq_target.shape[-1] - wvlpad_C],
    size=(R * upsample_factor, C * upsample_factor), mode='nearest')
expander_phase_b = expander_phase
expander_amp_b = expander_amp
slm_phase_b = slm_phase
slm_amp_b = slm_amp
output_b = output

### Display Étendue Expanded Hologram and Target Image ###

In [None]:
def _stack_rgb(red, green, blue):
    """Stack three single-channel NCHW tensors into an HxWx3 image in [0, 1].

    Concatenates along dim 1, moves channels last, keeps the first batch item,
    and clips to the displayable range. Assumes each input is (N, 1, H, W);
    TODO confirm for the get_intensity outputs.
    """
    rgb = torch.cat((red, green, blue), dim=1).permute(0, 2, 3, 1)
    return np.clip(rgb[0].detach().cpu().numpy(), 0.0, 1.0)


output_tri = _stack_rgb(output_r, output_g, output_b)
plt.figure()
plt.title('Étendue Expanded Hologram')
plt.imshow(output_tri)

freq_target_tri = _stack_rgb(freq_target_r, freq_target_g, freq_target_b)
plt.figure()
plt.title('Target Image')
plt.imshow(freq_target_tri)