In [None]:
# Reload edited modules (e.g. snmfem) automatically before running code
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook
%matplotlib inline

In [None]:
import numpy as np
import matplotlib.pyplot as plt

from pathlib import Path
from snmfem.conf import DATASETS_PATH

from snmfem.models import EDXS

In [None]:
# Load an EDXS model.
# Continuum / absorption parameters; the same b*/c* names are used by the
# numpy and torch re-implementations further down.
brstlg_params = dict(
    c0=4.8935e-05,
    c1=1464.19810,
    c2=0.04216872,
    b0=0.15910789,
    b1=-0.00773158,
    b2=8.7417e-04,
)
model_parameters = dict(
    params_dict=brstlg_params,
    db_name="default_xrays.json",
    e_offset=0.208,
    e_scale=0.01,
    e_size=1980,
    width_slope=0.01,
    width_intercept=0.065,
    seed=1,
)

model = EDXS(**model_parameters)
# Grid on which the model is evaluated (presumably the energy axis); used as `e` below
e = model.x

The next cell loads a synthetic dataset that does not ship with the repository: generate it first by running `python script/generate_synthetic_dataset.py`.

In [None]:
# Load one synthetic EDXS sample.
sample = DATASETS_PATH / Path("aspim037_N100_2ptcls_brstlg") / Path("sample_0.npz")
# np.load on an .npz archive returns an NpzFile that keeps the file open until it
# is closed; the context manager releases the handle once the arrays are read.
with np.load(sample) as data:
    X = data["X"]
    Xdot = data["Xdot"]
    densities = data["densities"]
    phases = data["phases"]
    true_maps = data["weights"]
    G = data["G"]

nx, ny, ns = X.shape
k = true_maps.shape[2]  # number of phases / maps
assert true_maps.shape[:2] == (nx, ny), "maps and data must share the same spatial grid"

# Flatten the spatial axes: data become (ns, nx*ny), maps become (k, nx*ny)
Xflat = X.transpose([2, 0, 1]).reshape(ns, nx * ny)
Xdotflat = Xdot.transpose([2, 0, 1]).reshape(ns, nx * ny)
true_maps_flat = true_maps.transpose([2, 0, 1]).reshape(k, nx * ny)

# Each row of `phases` scaled by its density (assumes one density per row — TODO confirm)
true_spectra = np.expand_dims(densities, axis=1) * phases


In [None]:
# A holds the true maps, one row per phase: shape (k, nx*ny) — same values as
# the flattening computed when loading the sample.
A = true_maps_flat
# Solve A.T @ GP ≈ Xflat.T for GP (k, ns), then G @ P ≈ GP.T for P.
# rcond=None selects NumPy's current cutoff (eps * max(M, N)) and avoids the
# FutureWarning / version-dependent default of older releases.
GP = np.linalg.lstsq(A.T, Xflat.T, rcond=None)[0]
P = np.linalg.lstsq(G, GP.T, rcond=None)[0]


In [None]:
# Reference parameters, taken from the EDXS model configuration so that the
# numpy / torch re-implementations below use exactly the same values as the model.
model_params_dict = model_parameters["params_dict"]
c0, c1, c2 = model_params_dict["c0"], model_params_dict["c1"], model_params_dict["c2"]
b0, b1, b2 = model_params_dict["b0"], model_params_dict["b1"], model_params_dict["b2"]

# Numpy implementation


In [None]:
def Bmatrix(b0, b1, b2, c0, c1, c2, e):
    """
    Background (bremsstrahlung) spectrum as a column vector.

    B(e) = beta(e) * Gamma(e) * alpha(e) with
        beta  = b0/e + b1 + b2*e                     Bremsstrahlung continuum
        Gamma = exp(-c2/e^3) * (1 - exp(-c1/e^3))    detector response
        alpha = e^3 * (1 - exp(-c0/e^3)) / c0        self-absorption

    Parameters
    ----------
    b0, b1, b2, c0, c1, c2 : float
        Model parameters (here b1 and b2 are absolute, not relative to b0).
    e : ndarray, shape (n,)
        Evaluation grid.

    Returns
    -------
    ndarray, shape (n, 1)
        Column vector, so it broadcasts across the phase columns of ``G @ P``.
    """
    beta = b0/e + b1 + b2*e
    Gamma = np.exp(-c2 / e**3) * (1 - np.exp(-c1 / e**3))
    alpha = e**3 * (1 - np.exp(- c0 / e**3)) / c0
    return np.expand_dims(beta * Gamma * alpha, axis=1)


def loss(b0, b1, b2, c0, c1, c2, e, G, P, A, X):
    """
    Poisson negative log-likelihood (up to a constant) of X under (G @ P + B) @ A.

    Terms are summed over all entries, consistent with ``torch_loss`` and with the
    analytical gradient ``dLdB``; averaging instead would differ from ``torch_loss``
    by a factor ``X.size`` and break the numpy/torch comparison.

    Parameters
    ----------
    b0, b1, b2, c0, c1, c2, e : see ``Bmatrix``.
    G, P : ndarrays whose product G @ P has one row per entry of e.
    A : ndarray, maps with one row per column of G @ P.
    X : ndarray, observed data with the shape of (G @ P + B) @ A.

    Returns
    -------
    float
    """
    B = Bmatrix(b0, b1, b2, c0, c1, c2, e)
    GPA = (G @ P + B) @ A
    return -np.sum(X * np.log(GPA)) + np.sum(GPA)

In [None]:
# Evaluate the numpy background and loss at the reference parameters
B = Bmatrix(b0=b0, b1=b1, b2=b2, c0=c0, c1=c1, c2=c2, e=e)
l = loss(b0=b0, b1=b1, b2=b2, c0=c0, c1=c1, c2=c2, e=e, G=G, P=P, A=A, X=Xflat)


# Torch implementation

In [None]:
import torch
from torch.autograd import Variable


In [None]:
# Use double precision for tensors built from Python floats: torch.tensor(c0) would
# otherwise give float32 parameters, rounding them relative to the float64 numpy
# reference and blurring the numpy/torch comparison below.
torch.set_default_dtype(torch.float64)

# Parameters we differentiate with respect to (leaf tensors).
# `Variable` is deprecated: plain tensors carry autograd state themselves.
t_c0 = torch.tensor(c0, requires_grad=True)
t_c1 = torch.tensor(c1, requires_grad=True)
t_c2 = torch.tensor(c2, requires_grad=True)
t_b0 = torch.tensor(b0, requires_grad=True)
t_b1 = torch.tensor(b1, requires_grad=True)
t_b2 = torch.tensor(b2, requires_grad=True)

# Fixed inputs (no gradient); torch.tensor copies the numpy data.
t_e = torch.tensor(e)
t_G = torch.tensor(G)
t_P = torch.tensor(P)
t_A = torch.tensor(A)
t_X = torch.tensor(Xflat)


In [None]:
def torch_Bmatrix(b0, b1, b2, c0, c1, c2, e):
    """
    Torch port of ``Bmatrix``: background spectrum as a (len(e), 1) column.

    Differentiable with respect to any of b0..c2 passed as tensors with
    ``requires_grad=True``.
    """
    e_cubed = e**3
    continuum = b0 / e + b1 + b2 * e
    detector_response = torch.exp(-c2 / e_cubed) * (1 - torch.exp(-c1 / e_cubed))
    self_absorption = e_cubed * (1 - torch.exp(-c0 / e_cubed)) / c0
    return (continuum * detector_response * self_absorption).unsqueeze(1)

def torch_loss(b0, b1, b2, c0, c1, c2, e, G, P, A, X):
    """
    Torch port of the Poisson negative log-likelihood of X under (G @ P + B) @ A,
    summed over all entries.
    """
    background = torch_Bmatrix(b0, b1, b2, c0, c1, c2, e)
    model_counts = (G @ P + background) @ A
    return torch.sum(model_counts) - torch.sum(X * torch.log(model_counts))

### Check that the forward functions give the same results in torch and in numpy

In [None]:
B_torch = torch_Bmatrix(t_b0, t_b1, t_b2, t_c0, t_c1, t_c2, t_e)

# Relative L2 discrepancy between the numpy and torch background spectra
B_torch_np = B_torch.detach().numpy()
np.linalg.norm(B - B_torch_np) / np.linalg.norm(B)

In [None]:
l_torch = torch_loss(t_b0, t_b1, t_b2, t_c0, t_c1, t_c2, t_e, t_G, t_P, t_A, t_X)

# Relative discrepancy between the numpy loss `l` and the torch loss (both scalars).
# NOTE(review): only small if `loss` and `torch_loss` aggregate the same way.
l_torch_np = l_torch.detach().numpy()
np.linalg.norm(l - l_torch_np) / np.linalg.norm(l)

### Compute the gradients

In [None]:
# backward() accumulates into .grad, so clear gradients left by a previous run of
# this cell first. Resetting to None (instead of assigning torch.zeros, whose dtype
# follows the global default and may not match the parameter) lets autograd
# allocate each gradient with its parameter's own dtype.
for param in (t_c0, t_c1, t_c2, t_b0, t_b1, t_b2):
    param.grad = None

l_torch = torch_loss(t_b0, t_b1, t_b2, t_c0, t_c1, t_c2, t_e, t_G, t_P, t_A, t_X)
l_torch.backward()

# Here they are! The numbers are quite big, which is likely to cause problems for the optimization...
t_c0.grad, t_c1.grad, t_c2.grad, t_b0.grad, t_b1.grad, t_b2.grad


In [None]:
def calc_b(E, b0=None, b1=None, b2=None, c0=None, c1=None, c2=None):
    """
    Background spectrum B on the grid E.

    B = chapman_brstlg(E, b0, b1, b2) * detector(E, c1, c2) * self_abs(E, c0),
    i.e. Bremsstrahlung continuum x detector response x self-absorption.
    b1 and b2 are relative to b0 (see ``chapman_brstlg``). Every factor turns E
    into a column, so the result is (len(E), 1) for scalar parameters and has one
    column per entry when 1-D parameter arrays are passed.
    All parameters must be supplied; the None defaults are not usable values.
    """
    continuum = chapman_brstlg(E, b0, b1, b2)
    response = detector(E, c1, c2)
    absorption = self_abs(E, c0)
    return continuum * response * absorption

def chapman_brstlg(E, b0=None, b1=None, b2=None):
    """
    Bremsstrahlung (continuum) modelling function: b0 * (1/E + b1 + b2*E).

    b1 and b2 are therefore relative to b0. E is turned into a column, so the
    result is (len(E), 1) for scalar parameters and broadcasts to one column per
    entry for 1-D parameter arrays.
    All parameters must be supplied; the None defaults are not usable values.
    See Chapman et al., 1984, J. of microscopy, vol. 136, pp. 171
    """
    col_E = E[:, np.newaxis]
    shape_term = 1.0 / col_E + b1 + b2 * col_E
    return shape_term * b0

def detector(E, c1=None, c2=None):
    """
    Detector modelling function: exp(-c2 / E^3) * (1 - exp(-c1 / E^3)).

    The factor exp(-c2/E^3) models absorption in the dead layer (it shrinks as c2
    grows). The factor 1 - exp(-c1/E^3) models the fraction of photons absorbed
    in the detector itself.
    E is turned into a column: (len(E), 1) for scalar parameters, one column per
    entry for 1-D parameter arrays.
    All parameters must be supplied; the None defaults are not usable values.
    """
    e_cubed = np.power(E[:, np.newaxis], 3)
    dead_layer = np.exp(-c2 / e_cubed)
    absorbed = 1 - np.exp(-c1 / e_cubed)
    return dead_layer * absorbed

def self_abs(E, c0=None):
    """
    Self-absorption modelling function: E^3 * (1 - exp(-c0 / E^3)) / c0.

    Phi rho z model with a constant Phi rho z function. E is turned into a column:
    (len(E), 1) for scalar c0, one column per entry for a 1-D c0 array.
    c0 must be supplied; the None default is not a usable value.

    ``1 - exp(-x)`` is evaluated as ``-expm1(-x)``. For the small ratios
    x = c0 / E^3 reached at large E, the naive form is a difference of two
    numbers close to 1 and loses most of its significant digits.
    """
    e_cubed = np.power(E[:, np.newaxis], 3)
    return -e_cubed * np.expm1(-c0 / e_cubed) / c0

def calc_db1(E, b0, c0, c1, c2):
    """
    Partial derivative of B with respect to b1, with b1 taken relative to b0 as
    in ``chapman_brstlg``: dB/db1 = detector(E, c1, c2) * self_abs(E, c0) * b0.
    """
    efficiency = detector(E, c1, c2) * self_abs(E, c0)
    return efficiency * b0

def calc_db2(E, b0, c0, c1, c2):
    """
    Partial derivative of B with respect to b2, with b2 taken relative to b0 as
    in ``chapman_brstlg``: dB/db2 = E * detector(E, c1, c2) * self_abs(E, c0) * b0.
    """
    col_E = E[:, np.newaxis]
    weighted = col_E * detector(E, c1, c2)
    weighted = weighted * self_abs(E, c0)
    return weighted * b0

def calc_dc0(E,b0,b1,b2,c0,c1,c2):
    """
    Partial derivative of B with respect to c0.

    Only the self-absorption factor depends on c0. With u = E^3 and x = c0 / u,
    self_abs = u * (1 - exp(-x)) / c0, hence

        d self_abs / d c0 = u / c0^2 * (exp(-x) - 1 + x * exp(-x)).

    The bracket behaves like -x^2 / 2 for small x. Written literally it is a
    difference of O(1) numbers, so it cancels to pure rounding noise once x drops
    below ~1e-8 (large E, since x ~ 1/E^3). That noise is then amplified by
    u / c0^2. Writing ``exp(-x) - 1`` as ``expm1(-x)`` keeps the bracket accurate.

    Parameters follow ``chapman_brstlg`` / ``detector`` / ``self_abs`` (b1, b2
    relative to b0). Returns an array broadcast to (len(E), ...) like those helpers.
    """
    e_cubed = np.power(E[:, np.newaxis], 3)
    x = c0 / e_cubed
    bracket = np.expm1(-x) + x * np.exp(-x)
    return (
        chapman_brstlg(E,b0,b1,b2)
        * detector(E,c1,c2)
        * e_cubed
        / np.power(c0, 2)
        * bracket
    )

def calc_dc1(E, b0, b1, b2, c0, c1, c2):
    """
    Partial derivative of B with respect to c1.

    Only the detector factor depends on c1:
    d/dc1 [exp(-c2/E^3) * (1 - exp(-c1/E^3))] = exp(-c2/E^3) * exp(-c1/E^3) / E^3.
    """
    e_cubed = np.power(E[:, np.newaxis], 3)
    continuum = chapman_brstlg(E, b0, b1, b2)
    absorption = self_abs(E, c0)
    dead_layer = np.exp(-c2 / e_cubed)
    active = np.exp(-c1 / e_cubed)
    return continuum * absorption * dead_layer * active / e_cubed

def calc_dc2(E, b0, b1, b2, c0, c1, c2):
    """
    Partial derivative of B with respect to c2.

    Only the detector factor depends on c2:
    d/dc2 [exp(-c2/E^3) * (1 - exp(-c1/E^3))] = -exp(-c2/E^3) * (1 - exp(-c1/E^3)) / E^3.
    """
    e_cubed = np.power(E[:, np.newaxis], 3)
    continuum = chapman_brstlg(E, b0, b1, b2)
    absorption = self_abs(E, c0)
    dead_layer = np.exp(-c2 / e_cubed)
    absorbed = 1 - np.exp(-c1 / e_cubed)
    return -continuum * absorption * dead_layer * absorbed / e_cubed

def dLdB(x_matr, d_matr, a_matr):
    """
    Gradient of L = sum(D @ A) - sum(X * log(D @ A)) with respect to D.

    dL/dD = -(X / (D @ A)) @ A.T + (row sums of A), the row sums broadcast over
    the rows of D.
    NOTE(review): X is floored at 1e-150 before the division (replacing zeros and
    negatives); the intent of this floor is not evident from the code.
    """
    neg_ratio = -(x_matr.clip(min=1e-150) / (d_matr @ a_matr))
    row_sums = np.sum(a_matr, axis=1)
    return neg_ratio @ a_matr.T + row_sums

In [None]:
# One copy of each parameter per phase (k = number of maps in the loaded sample)
# so that calc_b returns one background column per phase. That matches the k
# columns of G @ P; hard-coding a count here breaks samples with another k.
ext_c0 = np.full(k, c0)
ext_c1 = np.full(k, c1)
ext_c2 = np.full(k, c2)
ext_b0 = np.full(k, b0)
ext_b1 = np.full(k, b1)
ext_b2 = np.full(k, b2)
# chapman_brstlg takes b1 and b2 relative to b0
myB = calc_b(e, ext_b0, ext_b1 / ext_b0, ext_b2 / ext_b0, ext_c0, ext_c1, ext_c2)

In [None]:
# Evaluate the analytical gradient at the same point as `loss` / `torch_loss`.
# Those use the fitted G @ P, not the unconstrained least-squares GP.T, so this
# is required for the numpy and torch gradients to be comparable.
myD = G @ P + myB
mydL = dLdB(Xflat, myD, A)

In [None]:
# Chain rule: contract dL/dB (mydL) with each dB/dtheta over the first axis,
# giving one gradient value per column (phase) of B.
b1_rel = ext_b1 / ext_b0
b2_rel = ext_b2 / ext_b0

def contract_with_dLdB(dB_dtheta):
    """Sum mydL * dB_dtheta over the first axis."""
    return np.einsum("i...,i...", mydL, dB_dtheta)

grad_b1 = contract_with_dLdB(calc_db1(e, ext_b0, ext_c0, ext_c1, ext_c2))
grad_b2 = contract_with_dLdB(calc_db2(e, ext_b0, ext_c0, ext_c1, ext_c2))
grad_c0 = contract_with_dLdB(calc_dc0(e, ext_b0, b1_rel, b2_rel, ext_c0, ext_c1, ext_c2))
grad_c1 = contract_with_dLdB(calc_dc1(e, ext_b0, b1_rel, b2_rel, ext_c0, ext_c1, ext_c2))
grad_c2 = contract_with_dLdB(calc_dc2(e, ext_b0, b1_rel, b2_rel, ext_c0, ext_c1, ext_c2))



In [None]:
# Numpy gradients, in the order c0, c1, c2, b1, b2
for grad in (grad_c0, grad_c1, grad_c2, grad_b1, grad_b2):
    print(grad)

In [None]:
# Torch gradients, for comparison with the numpy ones printed above.
# NOTE(review): the two sets differ in convention.
#   - numpy: one value per phase column of B, with b1/b2 relative to b0.
#   - torch: a single background shared by all phases, with raw b1/b2.
# When both are evaluated at the same G @ P, expect:
#   - torch c-gradients ≈ sum of the numpy ones;
#   - torch b1/b2 gradients ≈ (sum of the numpy ones) / b0.
tuple(param.grad for param in (t_c0, t_c1, t_c2, t_b0, t_b1, t_b2))