In [None]:
import os
import torch
import numpy as np
from math import *
from tqdm.notebook import tqdm
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchinfo import summary

import matplotlib
import matplotlib.pyplot as plt
from matplotlib import colors
from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable
from matplotlib.colors import Normalize
import cmcrameri.cm as cmc
%matplotlib inline

## Load simulation data and preprocess (normalize, add Poisson noise)

In [None]:
data_folder = '/YOUR/DATA/SAVING/PATH'

# Parameter grids; these must match the grids used by the simulation.py script.
strain = np.linspace(-0.005, 0.005, 41)  # Strain range
tilt_lr = np.linspace(-0.05, 0.05, 41)  # Tilt_LR range in degrees
tilt_ud = np.linspace(-0.1, 0.1, 41)  # Tilt_UD range in degrees

# Multipliers that bring the three physical parameters to the same order of
# magnitude so the MSE loss weights them roughly equally. Predictions must be
# divided by the same factors to recover physical units.
label_scales = np.array([100.0, 10.0, 5.0])

# Seed for the Poisson noise; None draws fresh entropy (non-reproducible run).
NOISE_SEED = None

sim_mat = np.load(os.path.join(data_folder, 'sim_mat.npy'))

# This order assumes dim. 0: strain, 1: tilt_lr, 2: tilt_ud, as generated using the simulation.py script
grid_shape = (len(strain), len(tilt_lr), len(tilt_ud))
assert sim_mat.shape[:3] == grid_shape, \
    f'sim_mat parameter axes {sim_mat.shape[:3]} do not match the label grid {grid_shape}'
sim_mat = np.reshape(sim_mat, (-1, sim_mat.shape[3], sim_mat.shape[4]))

# One label row per flattened pattern: with 'ij' indexing and a C-order reshape,
# row (p0*41 + p1)*41 + p2 holds [strain[p0], tilt_lr[p1], tilt_ud[p2]], which is
# the same order as the reshape of sim_mat above.
grid_s, grid_lr, grid_ud = np.meshgrid(strain, tilt_lr, tilt_ud, indexing='ij')
labels = np.stack([grid_s, grid_lr, grid_ud], axis=-1).reshape(-1, 3)
labels = labels * label_scales

print('Data shape: ', sim_mat.shape, '| Labels shape: ', labels.shape)

labels = np.around(labels, 5).astype(np.float32)
sim_mat = (sim_mat / np.max(sim_mat)) * 7 # Normalize the intensity to approximate experimental values

# Add Poisson (shot) noise to emulate a diffraction experiment, one pattern at a time.
rng = np.random.default_rng(NOISE_SEED)
for i in tqdm(range(sim_mat.shape[0])):
    sim_mat[i] = rng.poisson(sim_mat[i])

# Round to the nearest integer (number of photons) and convert to float32 so the
# tensors built from it match the float32 weights of the network.
sim_mat = np.rint(sim_mat).astype(np.float32)

## Prepare data for PyTorch

In [None]:
# Make pytorch Dataset

class SimDataset(Dataset):
    """Simulated diffraction dataset. Labels for params: strain (0), tilt_lr (1), tilt_ud (2), in order.

    Item ``idx`` is the dict ``{'image': data[idx], 'lattice': params[idx]}``,
    passed through ``transform`` when one is given.
    """

    def __init__(self, data, params, transform=None):
        """
        data (numpy array): simulated diffraction patterns
        params (numpy array): labels, row-aligned with ``data``
        transform (callable, optional): applied to every sample dict
        """
        self.data, self.params, self.transform = data, params, transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Samplers may hand over tensor indices; plain ints/lists index numpy arrays.
        idx = idx.tolist() if torch.is_tensor(idx) else idx
        sample = dict(image=self.data[idx], lattice=self.params[idx])
        return self.transform(sample) if self.transform else sample
    
class ToTensor(object):
    """Convert the numpy arrays in a sample dict to float32 tensors.

    The image gains a leading channel axis ((H, W) -> (1, H, W)) as expected by
    ``nn.Conv2d``; the lattice labels are converted without reshaping. Both are
    cast to float32 to match the default float32 weights of the network: a
    float64 array would otherwise become a double tensor and fail inside the
    first convolution. float32 inputs are wrapped without copying.
    """

    def __call__(self, sample):
        image = torch.as_tensor(sample['image'], dtype=torch.float32)
        lattice = torch.as_tensor(sample['lattice'], dtype=torch.float32)
        return {'image': torch.unsqueeze(image, 0), 'lattice': lattice}
    
# Build the full dataset; each sample is turned into tensors by ToTensor
diff_dataset = SimDataset(sim_mat, labels, transform=ToTensor())

## Hyperparameters and Constants

In [None]:
# Set NGPUS = torch.cuda.device_count() to spread training over every visible GPU.
NGPUS = 1
BATCH_SIZE = 64 * NGPUS   # 64 samples per GPU
LR = NGPUS * 0.0001       # learning rate scaled linearly with the GPU count
print("GPUs:", NGPUS, "| Batch size:", BATCH_SIZE, "| Learning rate:", LR)

EPOCHS = 500
MODEL_SAVE_PATH = '/MODEL/SAVE/PATH'

In [None]:
# Split into training, validation, and test sets (80/10/10). The fixed generator
# seed makes the split identical on every run.
generator0 = torch.Generator().manual_seed(8)
subsets = torch.utils.data.random_split(diff_dataset, [0.8, 0.1, 0.1], generator=generator0)

# Use a DataLoader to iterate through the Datasets. Only the training set is
# shuffled: evaluation order does not change the model, and a fixed order keeps
# the batch-averaged validation loss deterministic from epoch to epoch and the
# test predictions aligned with subsets[2] indices.
trainloader = DataLoader(subsets[0], batch_size=BATCH_SIZE, shuffle=True)
validloader = DataLoader(subsets[1], batch_size=BATCH_SIZE, shuffle=False)
testloader = DataLoader(subsets[2], batch_size=BATCH_SIZE, shuffle=False)

# Define Model

In [None]:
# Convolutional Neural Network

class NanobeamNN(nn.Module):
    """CNN that regresses the three (scaled) lattice parameters from one diffraction pattern.

    Three identical stages of [3x3 conv -> ReLU -> 3x3 conv -> ReLU -> 2x2 max-pool]
    widen the channels 1 -> 2 -> 4 -> 8 -> 16 -> 32 -> 64. For a 1 x 64 x 64 input the
    spatial size shrinks 64 -> 30 -> 13 -> 4, so flattening yields 64 * 4 * 4 = 1024
    features, which a single linear layer maps to 3 outputs. The layer order fixes the
    state-dict keys (operation.0/.2/.5/.7/.10/.12 and fc1).
    """

    def __init__(self):
        super().__init__()
        widths = [1, 2, 4, 8, 16, 32, 64]
        layers = []
        for first in range(0, len(widths) - 1, 2):
            c_in, c_mid, c_out = widths[first:first + 3]
            layers.extend([
                nn.Conv2d(c_in, c_mid, 3),
                nn.ReLU(),
                nn.Conv2d(c_mid, c_out, 3),
                nn.ReLU(),
                nn.MaxPool2d((2, 2)),
            ])
        self.operation = nn.Sequential(*layers)
        self.fc1 = nn.Linear(1024, 3)

    def forward(self, x):
        features = torch.flatten(self.operation(x), start_dim=1)
        return self.fc1(features)

## Check that all dimensions are correct

In [None]:
cnn = NanobeamNN()
# Pull a single batch. The batch_* names keep the global `labels` array (used to
# build diff_dataset) from being overwritten by a tensor.
sample_batch = next(iter(trainloader))
batch_inputs, batch_labels = sample_batch['image'], sample_batch['lattice']
print("inputs:", batch_inputs.shape, batch_labels.shape)
with torch.no_grad():  # shape check only; no gradients needed
    outputs = cnn(batch_inputs)
print("outputs:", outputs.shape)

In [None]:
summary(cnn, input_size=(1, 1, 64, 64), device='cpu')  # layer shapes/params for one 1x64x64 pattern

## Move model to device 

In [None]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

cnn = NanobeamNN()
if NGPUS > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    # DataParallel (default: all visible devices) scatters each batch along dim 0,
    # e.g. a batch of 30 becomes three chunks of 10 on 3 GPUs.
    cnn = nn.DataParallel(cnn)

cnn = cnn.to(device)

## Optimizer

In [None]:
# Mean-squared error on the scaled labels, minimized with Adam at learning rate LR.
criterion = nn.MSELoss()
optimizer = optim.Adam(params=cnn.parameters(), lr=LR)

# Training and Validation Loop

In [None]:
# Save model weights (called whenever the validation loss reaches a new minimum).
def update_saved_model(model, path, name):
    """Save the model's weights to ``path/name``, creating ``path`` if needed.

    If ``model`` is a wrapper such as ``nn.DataParallel``, the wrapped ``.module``
    is saved so the state-dict keys load directly into a bare ``NanobeamNN``.
    Only ``path/name`` is (over)written; other files in ``path`` -- such as the
    best model or ``checkpoint.state`` -- are left untouched.
    """
    os.makedirs(path, exist_ok=True)
    # An unwrapped nn.Module has no `module` attribute (its __getattr__ raises
    # AttributeError), so getattr falls back to the model itself.
    bare_model = getattr(model, 'module', model)
    torch.save(bare_model.state_dict(), os.path.join(path, name))

def generate_state_dict(epoch_num, metrics, optimizer):
    """
    Return the resumable training state, excluding the model weights.

    Keys: 'current_epoch' (epoch_num + 1, i.e. the next epoch to run),
    'optimizer_state_dict' (optimizer.state_dict()) and 'loss_tracker' (metrics).
    """
    state = {
        'current_epoch': epoch_num + 1,
        'optimizer_state_dict': optimizer.state_dict(),
        'loss_tracker': metrics
    }
    return state

def save_model_and_states_checkpoint(model, path, epoch_num, metrics, optimizer):
    """Save a checkpoint that can be loaded to continue training.

    Writes ``path/checkpoint.state`` (see generate_state_dict) and the model
    weights to ``path/checkpoint_model.pth``.
    """
    state_dict = generate_state_dict(epoch_num, metrics, optimizer)
    os.makedirs(path, exist_ok=True)
    torch.save(state_dict, os.path.join(path, 'checkpoint.state'))
    update_saved_model(model, path, 'checkpoint_model.pth')

In [None]:
def train(trainloader, metrics):
    """Run one training epoch and record its mean per-batch loss.

    Uses the notebook-level ``cnn``, ``criterion``, ``optimizer`` and ``device``.
    Appends the mean batch loss to ``metrics['losses']`` (NaN for an empty loader).
    """
    running_loss = 0.0
    n_batches = 0

    for data in trainloader:
        inputs, labels = data['image'].to(device), data['lattice'].to(device)

        outputs = cnn(inputs)

        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        n_batches += 1

    # Divide by the number of batches (not the last enumerate index, which is one less).
    metrics['losses'].append(running_loss / n_batches if n_batches else float('nan'))

def validate(validloader, metrics):
    """Evaluate the mean per-batch validation loss and save the model if it improved.

    Uses the notebook-level ``cnn``, ``criterion``, ``device`` and ``MODEL_SAVE_PATH``.
    Appends the mean batch loss to ``metrics['val_losses']``; when it is below
    ``metrics['best_val_loss']`` the best value is updated and the weights are
    written to ``MODEL_SAVE_PATH/best_model.pth``.
    """
    tot_val_loss = 0.0
    n_batches = 0

    # No gradients are needed for evaluation; this avoids building autograd graphs.
    with torch.no_grad():
        for sample in validloader:
            images, ground_truth = sample['image'].to(device), sample['lattice'].to(device)

            predicted = cnn(images)

            tot_val_loss += criterion(predicted, ground_truth).item()
            n_batches += 1

    mean_val_loss = tot_val_loss / n_batches if n_batches else float('nan')
    metrics['val_losses'].append(mean_val_loss)

    # NaN never compares smaller, so an empty loader never triggers a save.
    if mean_val_loss < metrics['best_val_loss']:
        print("Saving improved model after Val. Loss improved from %.5f to %.5f" 
              % (metrics['best_val_loss'], mean_val_loss))
        metrics['best_val_loss'] = mean_val_loss
        update_saved_model(cnn, MODEL_SAVE_PATH, 'best_model.pth')

In [None]:
metrics = {'losses': [], 'val_losses': [], 'best_val_loss': np.inf}

for epoch in tqdm(range(EPOCHS)):

    #Set model to train mode
    cnn.train()

    #Training loop
    train(trainloader, metrics)

    #Switch model to eval mode
    cnn.eval()

    #Validation loop
    validate(validloader, metrics)

    print('Epoch: %d | Train Loss: %.5f | Val. Loss: %.5f'
          %(epoch, metrics['losses'][-1], metrics['val_losses'][-1]))

# Resumable checkpoint (optimizer state, loss history, final weights). EPOCHS - 1
# is the index of the last epoch run, and is defined even when EPOCHS == 0.
save_model_and_states_checkpoint(cnn, MODEL_SAVE_PATH, EPOCHS - 1, metrics, optimizer)

print('Finished Training')

# Visualizations

In [None]:
# Functions for organizing the model statistics

DEVICE = "cpu"

def get_predictions(model, dataloader, batch_size=BATCH_SIZE, device=DEVICE):
    """Returns network predictions and labels by any NanobeamNN model on any SimDataset.

    model: network mapping a batch of images to (batch, 3) outputs; it is moved
        to ``device`` and left in eval mode.
    dataloader: yields dicts with 'image' and 'lattice' tensors.
    batch_size: unused, kept for backward compatibility. Batch sizes are taken
        from the data, so loaders with any batch size, a short last batch, or a
        single batch are all handled.
    device: device on which inference runs.

    Returns (pred_vals, gt_vals): float64 arrays of shape (n_samples, 3) in the
    loader's iteration order; both are (0, 3) for an empty loader.
    """
    pred_batches, gt_batches = [], []
    model.to(device)

    model.eval()
    with torch.no_grad():
        for data in dataloader:
            outputs = model(data['image'].to(device))
            pred_batches.append(outputs.cpu().numpy())
            gt_batches.append(data['lattice'].cpu().numpy())

    if not pred_batches:
        return np.zeros((0, 3)), np.zeros((0, 3))

    # float64 on purpose: make_idx_dict compares np.around(labels, 3) against
    # float64 grid values, which only matches exactly when the labels are float64.
    pred_vals = np.concatenate(pred_batches).astype(np.float64)
    gt_vals = np.concatenate(gt_batches).astype(np.float64)
    return pred_vals, gt_vals

def make_idx_dict(keys, label_arr, param_num):
    """Returns a dictionary with labels as keys and numpy arrays of indices where the label is the key.
    keys: iterable of label values (iterated once, so generators work too)
    label_arr: numpy array of labels, one row per sample
    param_num: 0 = strain, 1 = tilt_lr, 2 = tilt_ud
    Returns {key: np.argwhere(...)} where each value is a (k, 1) integer array of
    the rows whose column ``param_num`` equals ``key`` after rounding both to 3 decimals.
    """
    # Round the column once rather than once per key.
    rounded_col = np.around(label_arr[:, param_num], 3)
    return {key: np.argwhere(rounded_col == np.around(key, 3)) for key in keys}

def test_make_idx_dict(idx_dict):
    """Return the total number of indices accounted for. It should be equal to the total number of labels."""
    counter = 0
    for key in idx_dict.keys():
        counter += idx_dict[key].shape[0]
    return counter

def get_pred_err_stats(idx_arr, pred_arr, param_num):
    """Returns the mean and standard deviation of the predictions of one parameter over a subset of samples.
    idx_arr: (k, m) integer array whose first column holds row indices into
        pred_arr (e.g. the (k, 1) output of np.argwhere in make_idx_dict)
    pred_arr: array of predictions on a set of data, one row per sample
    param_num: 0 = strain, 1 = tilt_lr, 2 = tilt_ud
    Returns (mean, std) of pred_arr[idx, param_num]; both are NaN (with a numpy
    RuntimeWarning) when idx_arr is empty.
    """
    rows = np.asarray(idx_arr)[:, 0]
    # Accumulate in float64 regardless of the prediction dtype.
    values = np.asarray(pred_arr)[rows, param_num].astype(np.float64)
    return np.mean(values), np.std(values)

def combine_stats(idx_dict, pred_arr, param_num):
    """Return per-key prediction statistics as a (len(idx_dict), 2) array.

    Row k holds [mean, std] of the predictions of parameter ``param_num`` over
    the indices stored under the k-th key of ``idx_dict`` (dict iteration order),
    as computed by get_pred_err_stats.
    idx_dict: mapping label value -> index array, e.g. from make_idx_dict
    pred_arr: array of predictions, one row per sample
    param_num: 0 = strain, 1 = tilt_lr, 2 = tilt_ud
    """
    stats_arr = np.zeros((len(idx_dict), 2))
    for row, i_arr in enumerate(idx_dict.values()):
        stats_arr[row] = get_pred_err_stats(i_arr, pred_arr, param_num)
    return stats_arr

## Load trained model and make predictions

In [None]:
# Trained weights to evaluate. NOTE(review): the training loop above saves its best
# weights to os.path.join(MODEL_SAVE_PATH, 'best_model.pth'); point this there
# (or at a copy) when evaluating a freshly trained model.
TRAINED_MODEL_PATH = 'trained_model.pth'

cnn = NanobeamNN()
# map_location='cpu' lets weights saved from a GPU run load on a CPU-only machine;
# weights_only=True restricts unpickling to tensors and plain containers.
cnn.load_state_dict(torch.load(TRAINED_MODEL_PATH, map_location='cpu', weights_only=True))

pred_train, train_labels = get_predictions(cnn, trainloader)
pred_test, test_labels = get_predictions(cnn, testloader)

In [None]:
# Label grids in the scaled units the network was trained on (strain x100,
# tilt_lr x10, tilt_ud x5); these are the keys used to group predictions.
strain_label_range = np.linspace(-0.005, 0.005, 41)*100
print('Strain label range (not actual value range): ', strain_label_range)
tilt_lr_label_range = np.linspace(-0.05, 0.05, 41)*10
print('Tilt_lr label range: ', tilt_lr_label_range)
tilt_ud_label_range = np.linspace(-0.1, 0.1, 41)*5
print('Tilt_ud label range: ', tilt_ud_label_range)

# Make dictionaries for training and test labels
train_s_dict = make_idx_dict(strain_label_range, train_labels, 0)
train_lr_dict = make_idx_dict(tilt_lr_label_range, train_labels, 1)
train_ud_dict = make_idx_dict(tilt_ud_label_range, train_labels, 2)
print('Total indices (train):', test_make_idx_dict(train_s_dict), test_make_idx_dict(train_lr_dict),
      test_make_idx_dict(train_ud_dict))

test_s_dict = make_idx_dict(strain_label_range, test_labels, 0)
test_lr_dict = make_idx_dict(tilt_lr_label_range, test_labels, 1)
test_ud_dict = make_idx_dict(tilt_ud_label_range, test_labels, 2)
print('Total indices (test):', test_make_idx_dict(test_s_dict), test_make_idx_dict(test_lr_dict), 
      test_make_idx_dict(test_ud_dict))

# Compile the prediction statistics
train_stats_s = combine_stats(train_s_dict, pred_train, 0)
train_stats_lr = combine_stats(train_lr_dict, pred_train, 1)
train_stats_ud = combine_stats(train_ud_dict, pred_train, 2)

test_stats_s = combine_stats(test_s_dict, pred_test, 0)
test_stats_lr = combine_stats(test_lr_dict, pred_test, 1)
test_stats_ud = combine_stats(test_ud_dict, pred_test, 2)

# Distribution of the signed prediction error (ground truth - prediction) in physical units.
# Factors that undo the label scaling, in parameter order (strain, tilt_lr, tilt_ud).
label_scales = np.array([100, 10, 5])

def _signed_errors_and_bins(param_num, n_bins=20):
    """Return (train_err, test_err, bin_edges) for one parameter.

    Errors are label minus prediction divided by the label scale; the n_bins
    edges are symmetric about 0 and cover the largest |error| of both sets.
    """
    scale = label_scales[param_num]
    train_err = train_labels[:, param_num]/scale - pred_train[:, param_num]/scale
    test_err = test_labels[:, param_num]/scale - pred_test[:, param_num]/scale
    half_width = np.abs(np.concatenate([train_err, test_err])).max()
    return train_err, test_err, np.linspace(-half_width, half_width, n_bins)

x0, y0, bins0 = _signed_errors_and_bins(0)  # strain
x1, y1, bins1 = _signed_errors_and_bins(1)  # tilt_lr (deg.)
x2, y2, bins2 = _signed_errors_and_bins(2)  # tilt_ud (deg.); its labels are scaled by 5, not 10

## Prediction statistics

In [None]:
# Mean +/- std of the predictions at each ground-truth grid value, in physical units
# (Figure S2a in the Supplementary Information). Test error bars are shifted slightly
# right so they do not sit on top of the training ones.
panels = [
    # (grid start, grid stop, label scale, test x-offset, train stats, test stats, title)
    (-0.005, 0.005, 100, 0.0001, train_stats_s, test_stats_s, 'Strain'),
    (-0.05, 0.05, 10, 0.0015, train_stats_lr, test_stats_lr, 'In-plane rotation (deg.)'),
    (-0.1, 0.1, 5, 0.0025, train_stats_ud, test_stats_ud, 'Out-of-plane rotation (deg.)'),
]

f, ax = plt.subplots(figsize=(15, 4), ncols=3)

for k, (start, stop, scale, offset, tr_stats, te_stats, title) in enumerate(panels):
    grid = np.linspace(start, stop, 41)
    ax[k].errorbar(grid, tr_stats[:, 0]/scale, yerr=tr_stats[:, 1]/scale,
                   label='train', fmt='none', capsize=3, color='blue')
    ax[k].errorbar(grid+offset, te_stats[:, 0]/scale, yerr=te_stats[:, 1]/scale,
                   label='test', fmt='none', capsize=3, color='red')
    ax[k].set_xlabel('Ground truth')
    ax[k].set_ylabel('Predicted')
    ax[k].set_title(title)
    ax[k].set_aspect(1)
    if k == 0:
        ax[k].legend()

In [None]:
# Histograms of the signed prediction error per parameter, train vs. test
# (Figure S2b in the Supplementary Information).

f, ax = plt.subplots(figsize=(14, 4), ncols=3)

hist_panels = [
    ((x0, y0), bins0, 'Pred. error', 'Strain'),
    ((x1, y1), bins1, 'Pred. error (deg.)', 'In-plane rotation'),
    ((x2, y2), bins2, 'Pred. error (deg.)', 'Out-of-plane rotation'),
]
for axis, ((train_err, test_err), edges, xlabel, title) in zip(ax, hist_panels):
    axis.hist([train_err, test_err], edges, label=['train', 'test'])
    axis.set_xlabel(xlabel)
    axis.set_title(title)
ax[0].legend()

plt.tight_layout()

## Diffraction comparison

In [None]:
# Generate simulated diffraction patterns from NanobeamNN predicted parameters
# This cell fixes the geometry used by Thickness_Fringe (next cell): momentum
# transfers det_Q* for each detector pixel and a grid O_* of incident-beam directions.

# Sub-pixel sampling factor of the detector grid; Thickness_Fringe averages
# upsampling1 x upsampling1 blocks back down to 64 x 64 pixels.
upsampling1 = 1

# 12.398/energy is the photon-energy -> wavelength conversion
# (presumably keV -> Angstrom; confirm).
energy = 11.3
wavelength = 12.398/energy
# Magnitude of the incident wavevector.
K=2*pi/wavelength
# Out-of-plane lattice constant (presumably Angstrom, same unit as wavelength) and
# index l of the (0 0 l) reflection.
c = 4.013
l = 2
# Bragg condition wavelength = 2*(c/l)*sin(alf0) for the (0 0 l) reflection.
alf0 = asin(wavelength*l/2/c)
# Incidence angle set exactly at the Bragg angle; twotheta is the scattering angle.
alf = alf0 
twotheta = (2 * alf0) 
# X0: detector column at which the exit angle equals gam0 (see `gam` below).
# Xcen: column at the centre of the simulated 64-pixel window (equal to X0 here).
X0 =  256 
Xcen = 256 

# Sample-to-detector distance and effective pixel size (presumably metres; confirm).
distance = 0.85
pixelsize = 55e-6/upsampling1*2 # The *2 is for binning

# Exit angle at detector column X0.
gam0 = twotheta-alf

# Half-angles of the annular illumination cone: the outer edge comes from the FZP
# aperture and the inner edge from the central stop (CS), each divided by the focal length.
# NOTE(review): the comments below say 150 um / 75 um but 149e-6 / 77e-6 are used; confirm.
focal_length = 21.874e-3
outer_angle = 149e-6/2/focal_length # diameter of FZP is 150 um
inner_angle = 77e-6/2/focal_length # diameter of CS is 75 um 

# Q-matching tolerance. NOTE: Thickness_Fringe has its own `precision` argument
# (default 5e-4) and the visible call does not pass this variable, so the
# function does not read this module-level value.
precision = 5e-4 # for fast numerical integration

# Detector pixel coordinates, 64*upsampling1 per axis: x is centred on
# X0 - (X0-Xcen)*upsampling1 (= Xcen for upsampling1 = 1), y is centred on 0.
det_x = np.arange(64*upsampling1).astype(np.float64)
det_y = np.arange(64*upsampling1).astype(np.float64)
det_x = det_x - det_x.mean() + X0 - (X0-Xcen)*upsampling1
det_y -= det_y.mean()

det_xx, det_yy = np.meshgrid(det_x,det_y)

# Exit angle of each pixel; it varies only with the x offset from column X0.
gam = np.arcsin((det_xx-X0)*pixelsize/distance)+gam0

#detector
# Momentum transfer per pixel: Qx and Qz from the angles alf and gam, Qy from the
# pixel's y offset (small-angle approximation pixel*pixelsize/distance).
det_Qx = K*(np.cos(alf)-np.cos(gam))
det_Qz = K*(np.sin(gam)+np.sin(alf))
det_Qy = det_yy*pixelsize/distance*K

# Grid of incident-beam directions, 64*upsampling2 per axis, centred on 0. Its
# angular step is pixelsize*upsampling1/upsampling2/distance (see O_angle below).
upsampling2 = 2

O_x = np.arange(64*upsampling2).astype(np.float64)
O_y = np.arange(64*upsampling2).astype(np.float64)
O_x -= O_x.mean()
O_y -= O_y.mean()

O_xx, O_yy = np.meshgrid(O_x,O_y)
# Trailing singleton axes give shape (N, N, 1, 1), so these broadcast against the
# 2-D detector arrays inside Thickness_Fringe.
O_xx = O_xx[:,:,np.newaxis,np.newaxis]
O_yy = O_yy[:,:,np.newaxis,np.newaxis]

# origin of the reciprocal space
# (i.e. the shift of the reciprocal-space origin for each incident-beam direction)
O_Qx = -O_xx*pixelsize*upsampling1/upsampling2/distance*K*sin(alf)
O_Qz = O_xx*pixelsize*upsampling1/upsampling2/distance*K*cos(alf) # the sign of Qx and Qz are opposite in this convention
O_Qy = O_yy*pixelsize*upsampling1/upsampling2/distance*K
# Angular distance of each incident direction from the grid centre; O_donut keeps
# only directions with inner_angle < angle < outer_angle.
O_angle = np.sqrt(O_yy**2+O_xx**2)*pixelsize*upsampling1/upsampling2/distance
O_donut = (O_angle < outer_angle) * (O_angle > inner_angle)

def Thickness_Fringe(thickness=117, strain=0, tilt_lr=0, tilt_ud=0, precision=5e-4):
    """Simulate a 64 x 64 diffraction pattern for the given film parameters.

    Let G = 2*pi*l / (c*(1+strain)). For every incident-beam direction in the
    module-level grid (O_Qx, O_Qy, O_Qz) and every detector pixel, the contribution is
        thickness * np.sinc(thickness * (det_Qz - G - O_Qz) / (2*pi))**2
    where np.sinc(x) = sin(pi*x)/(pi*x). The contribution is kept only where
    |det_Qx + G*radians(tilt_lr) - O_Qx| < precision and
    |det_Qy + G*radians(tilt_ud) - O_Qy| < precision, and is multiplied by the
    O_donut mask. Contributions are summed over all incident directions (axes
    0, 1), and upsampling1 x upsampling1 pixel blocks are averaged, giving a
    64 x 64 array.

    thickness: thickness used in the fringe argument (same length unit as 1/Q,
        presumably Angstrom; confirm); also scales the intensity.
    strain: fractional change applied to the lattice constant c.
    tilt_lr, tilt_ud: tilts in degrees (converted with radians()); they shift
        the allowed Qx and Qy respectively.
    precision: Q-matching tolerance; shadows the module-level `precision`.

    Relies on module-level det_Qx/det_Qy/det_Qz, O_Qx/O_Qy/O_Qz, O_donut, c, l
    and upsampling1. Broadcasting the (N, N, 1, 1) O_* arrays against the
    detector arrays creates N*N*(64*upsampling1)**2-element temporaries
    (hundreds of MB at the settings in the previous cell).
    """
    return ((thickness*np.sinc(thickness*(det_Qz-2*pi/c*l/(1+strain)-O_Qz)/pi/2)**2 *\
            (np.abs(det_Qx+2*pi/c*l/(1+strain)*radians(tilt_lr)-O_Qx)<precision) *\
            (np.abs(det_Qy+2*pi/c*l/(1+strain)*radians(tilt_ud)-O_Qy)<precision))*O_donut)\
            .sum(axis=(0,1)).reshape(64,upsampling1,64,upsampling1).mean(axis=(1,3))

In [None]:
# Choose some points from the test data set to compare input images and simulation(output values)
idx_list = np.array([183, 1046, 5021, 5706]) # Random, test set is not shuffled for reproducibility

# testloader is not shuffled, so row idx of pred_test is item idx of
# testloader.dataset; index the dataset directly instead of rescanning the whole
# loader for every point. Using `sample` keeps the global `labels` array intact.
test_subset = testloader.dataset
plot_images = np.zeros((len(idx_list), 64, 64)) # Input images
plot_labels = np.zeros((len(idx_list), 3)) # Ground truth labels
for k, idx in enumerate(idx_list):
    sample = test_subset[int(idx)]
    plot_images[k] = sample['image'].squeeze().detach().numpy()
    plot_labels[k] = sample['lattice'].squeeze().detach().numpy()

# Re-simulate each selected pattern from its predicted parameters, undoing the
# x100 / x10 / x5 label scaling; 117 matches Thickness_Fringe's default thickness.
sim_pred = np.zeros((len(idx_list), 64, 64))
for k, idx in enumerate(idx_list):
    sim_pred[k] = Thickness_Fringe(117, pred_test[idx, 0]/100,
                                   pred_test[idx, 1]/10, pred_test[idx, 2]/5)

sim_pred = (sim_pred / np.max(sim_pred)) * 7

# Same noise model as the training data: Poisson counts at the same intensity scale.
rng = np.random.default_rng()
for k in range(sim_pred.shape[0]):
    sim_pred[k] = rng.poisson(sim_pred[k])

# Round the data to the nearest integer, but keep dtype as float32
sim_pred = np.rint(sim_pred).astype('float32')

In [None]:
# Plot the comparison between input images and simulated patterns from output parameters
vmin = 1
vmax = 7
normalizer = colors.LogNorm(vmin, vmax)

f, ax = plt.subplots(figsize=(11, 4.5), nrows=2, ncols=4)

# Top row: noisy test inputs; bottom row: patterns re-simulated from the predictions.
row_specs = [(plot_images, 'Input %d'), (sim_pred, 'Sim. from pred. %d')]
for row, (stack, title_fmt) in enumerate(row_specs):
    for col in range(ax.shape[1]):
        im = ax[row, col].imshow(stack[col], interpolation='none', cmap='jet', norm=normalizer)
        ax[row, col].set_xticks([])
        ax[row, col].set_yticks([])
        ax[row, col].set_title(title_fmt % col)

f.subplots_adjust(hspace=0.15, wspace=0.05, right=0.9)
cbar_ax = f.add_axes([0.90, 0.12, 0.02, 0.75])
# Every panel shares `normalizer` and the 'jet' colormap, so any of the images is a
# valid mappable for the shared colorbar (no separate ScalarMappable/cm import needed).
f.colorbar(im, cax=cbar_ax)
plt.show()

In [None]:
# Selected points from cell above, refer to Figure 2a of the main article

vmin = 1
vmax = 8
normalizer = colors.LogNorm(vmin, vmax)

f, ax = plt.subplots(figsize=(3.5, 3), nrows=2, ncols=2, dpi=300)

# Columns show selections 0 and 2; rows show the input vs. the re-simulated pattern.
row_specs = [(plot_images, 'Simulated'), (sim_pred, 'Predicted')]
for row, (stack, row_label) in enumerate(row_specs):
    for col, sample_idx in enumerate((0, 2)):
        im = ax[row, col].imshow(stack[sample_idx], interpolation='none', norm=normalizer, cmap=cmc.lapaz)
        ax[row, col].set_xticks([])
        ax[row, col].set_yticks([])
    ax[row, 0].set_ylabel(row_label)

f.subplots_adjust(hspace=0.15, wspace=0, right=0.85)
cbar_ax = f.add_axes([0.85, 0.12, 0.03, 0.75])
# All panels share `normalizer` and cmc.lapaz, so any image can drive the colorbar.
f.colorbar(im, cax=cbar_ax, label='Photons')