### Imports

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import nibabel as nib
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as utils
from tqdm import tqdm
from scipy.special import erf
import scipy.stats

from dmipy.core.acquisition_scheme import acquisition_scheme_from_bvalues
from dmipy.core.modeling_framework import MultiCompartmentSphericalMeanModel
from dmipy.signal_models import sphere_models, cylinder_models, gaussian_models

from scipy.io import savemat

  from .autonotebook import tqdm as notebook_tqdm


### Calculate ADC' (Prime)
#### I have a feeling this would be much more useful if it took arrays as inputs and outputted an array. But this current method is good for the list comprehension line. 

In [2]:
def calc_adc_prime(adc, sigma, axr, tm):
   adc_prime = adc * (1 - sigma* np.exp(-tm*axr))
   return adc_prime

def calc_adc_prime_tensor(adc, sigma, axr, tm):
   adc = torch.tensor(adc)
   sigma = torch.tensor(sigma)
   axr = torch.tensor(axr)

   adc_prime = adc * (1 - sigma* torch.exp(-tm*axr))

   return adc_prime


### Simulate Signal

In [3]:
def simulate_sig(adc, sigma, axr, bf, be, tm, acq):
    """Generate an normalised output signal s based on known inputs for a given voxel
Inputs  - adc:      apparent diffusion coefficient [m2/s]
        - sigma:    filter efficiency
        - axr:      exchange rate [1/s]
        - bf:       filter block b-value [m2/s]
        - be:       encoding block b-value [m2/s]
        - tm:       mixing time [s]

Output: - s:        signal (sum of the magnetisations) single value
   Based off Elizabeth's code
    """

    if bf == 0 and tm[acq] == min(tm):
        tm[acq] = np.inf

    #calculate ADC as fnc of mixing time

    adc_prime = calc_adc_prime(adc,sigma,axr,tm[acq])

    #compute signal
    normalised_signal = np.exp(-adc_prime*be)

    return normalised_signal, adc_prime
    

In [4]:
def simulate_sig_tensor(adc, sigma, axr, bf, be, tm, acq):
    """Generate an normalised output signal s based on known inputs for a given voxel
Inputs  - adc:      apparent diffusion coefficient [m2/s]
        - sigma:    filter efficiency
        - axr:      exchange rate [1/s]
        - bf:       filter block b-value [m2/s]
        - be:       encoding block b-value [m2/s]
        - tm:       mixing time [s]

Output: - s:        signal (sum of the magnetisations) single value
   Based off Elizabeth's code
    """
    #maybe move within the forward function
    if bf == 0 and tm[acq] == min(tm):
        tm[acq] = np.inf

    #calculate ADC as fnc of mixing time

    adc_prime = calc_adc_prime_tensor(adc,sigma,axr,tm[acq])
    
    be = torch.tensor(be)  # Convert be to a PyTorch tensor

    normalised_signal = torch.exp(-adc_prime * be)

    return normalised_signal, adc_prime
    

In [5]:
# Attempt at simulating the signal using arrays
"""
nvox1 = 15

bf = np.array([0, 0, 250, 250, 250, 250, 250, 250]) * 1e6   # filter b-values [s/m2]
be = np.array([0, 250, 0, 250, 0, 250, 0, 250]) * 1e6       # encoding b-values [s/m2]
tm = np.array([20, 20, 20, 20, 200, 200, 400, 400]) * 1e-3  # mixing time [s]

be1 = np.tile(be,(nvox1,1))
bef = np.tile(bf,(nvox1,1))
tm1 = np.tile(tm,(nvox1,1))
"""

'\nnvox1 = 15\n\nbf = np.array([0, 0, 250, 250, 250, 250, 250, 250]) * 1e6   # filter b-values [s/m2]\nbe = np.array([0, 250, 0, 250, 0, 250, 0, 250]) * 1e6       # encoding b-values [s/m2]\ntm = np.array([20, 20, 20, 20, 200, 200, 400, 400]) * 1e-3  # mixing time [s]\n\nbe1 = np.tile(be,(nvox1,1))\nbef = np.tile(bf,(nvox1,1))\ntm1 = np.tile(tm,(nvox1,1))\n'

### Initial variables.

In [6]:
nvox = 1000 # number of voxels to simulate

bf = np.array([0, 0, 250, 250, 250, 250, 250, 250]) * 1e-3   # filter b-values [ms/um2]
be = np.array([0, 250, 0, 250, 0, 250, 0, 250]) * 1e-3       # encoding b-values [ms/um2]
tm = np.array([20, 20, 20, 20, 200, 200, 400, 400])  # mixing time [ms]


sim_adc = np.random.uniform(0.1,3.5,nvox)               # ADC, simulated [um2/ms]
sim_sigma = np.random.uniform(0,1,nvox)                 # sigma, simulated [a.u.]
sim_axr = np.random.uniform(0.1,20,nvox) * 1e-3         # AXR, simulated [ms-1]

# simulate signals    
sigs_and_adc_prime = np.array([[simulate_sig(sim_adc[voxel], sim_sigma[voxel], sim_axr[voxel], bf[acq], be[acq], tm, acq) 
                    for acq in range(np.size(tm))] 
                    for voxel in range(nvox)])

#sim_E_vox is the signal
sim_E_vox = sigs_and_adc_prime[:, :, 0]
sim_adc_prime = sigs_and_adc_prime[:, :, 1]

OverflowError: cannot convert float infinity to integer

### Plotting b-value against normalised signal

In [None]:
plt.plot([be[0], be[1]], [sim_E_vox[0,0], sim_E_vox[0,1]], 'bo-')
plt.annotate(tm[0], (be[1], sim_E_vox[0, 1]), textcoords="offset points", xytext=(20,5), ha='center')

plt.plot([be[2], be[3]], [sim_E_vox[0,2], sim_E_vox[0,3]], 'go-')
plt.annotate(tm[2], (be[3], sim_E_vox[0, 3]), textcoords="offset points", xytext=(20,5), ha='center')

plt.plot([be[4], be[5]], [sim_E_vox[0,4], sim_E_vox[0,5]], 'ko-')
plt.annotate(tm[4], (be[5], sim_E_vox[0, 5]), textcoords="offset points", xytext=(20,5), ha='center')

plt.plot([be[6], be[7]], [sim_E_vox[0,6], sim_E_vox[0,7]], 'mo-')
plt.annotate(tm[6], (be[7], sim_E_vox[0, 7]), textcoords="offset points", xytext=(20,5), ha='center')


#plt.title('Scatter plot with 4 lines')
plt.xlabel('Encoding block b-value [m2/s]')
#are units correct
plt.ylabel('Normalised Signal (sum of the magnetisations)')
#unit?
plt.grid(True)
plt.show()


### Creating the neural network

In [None]:
class Net(nn.Module): # this is the neural network
    #defining the init and foward pass functions. 

    def __init__(self,be,bf,tm,nparams,batch_size):
        super(Net, self).__init__()

        self.be = be
        self.bf = bf
        self.tm = tm
        self.batch_size = batch_size

        #defining the layers that we want. 
        # 3 layers with no. of be nodes. 
        self.layers = nn.ModuleList()
        for i in range(3): # 3 fully connected hidden layers
            self.layers.extend([nn.Linear(len(be), len(be)), nn.PReLU()])
            #https://pytorch.org/docs/stable/generated/torch.nn.PReLU.html
        self.encoder = nn.Sequential(*self.layers, nn.Linear(len(be), nparams))

    def forward(self, E_vox):

        params = torch.nn.functional.softplus(self.encoder(E_vox))
        #running a forward pass through the network

        #SoftPlus is a smooth approximation to the ReLU function and can be used to constrain the output of a machine to always be positive
        #params contains batch_size x nparams outputs, so each row is adc, sigma and axr.

        #unsqueeze adds an additional dimension. 
        #parameter constraints from Elizabeth matlab 
        adc = torch.clamp(params[:,0].unsqueeze(1), min=0.1, max=3.5)
        """I have a feeling adc is worst plot because it is e9 or e10, wheres others are between 0-1 or 20 
        Maybe it has do with softplus
        """
        sigma = torch.clamp(params[:,1].unsqueeze(1), min=0, max=1)
        axr = torch.clamp(params[:,2].unsqueeze(1), min=.1e-3, max=20e-3)

        E_vox = torch.tensor([
                [simulate_sig_tensor(sim_adc[voxel], sim_sigma[voxel], sim_axr[voxel], self.bf[acq], self.be[acq], self.tm, acq)[0]
                for acq in range(self.tm.size()[0])]
                for voxel in range(self.batch_size)
                ])


        return E_vox, adc, sigma, axr


### NN continued

In [None]:
# define network
nparams = 3
#because of adc, sigma and axr

#converting numpy arrays to pytorch tensors. 
be = torch.tensor(be)
bf = torch.tensor(bf)
tm = torch.tensor(tm)
batch_size = 128

#initilise network
net = Net(be, bf, tm, nparams, batch_size)

#create batch queues for data
#// means divide and round down. 
num_batches = len(sim_E_vox) // batch_size

#import the sim_E_vox array into the dataloader
#drop_last ignores the last batch if it is the wrong size. 
#num_workers is about performance. 

trainloader = utils.DataLoader(torch.from_numpy(sim_E_vox.astype(np.float32)),
                                batch_size = batch_size, 
                                shuffle = True,
                                num_workers = 0, #was 2 previously
                                drop_last = True)

# loss function and optimizer
#choosing which loss function to use. 
#not sure what the optmizer is
criterion = nn.MSELoss()
optimizer = optim.Adam(net.parameters(), lr = 0.0001)

# best loss
best = 1e16
num_bad_epochs = 0
#can increase patience a lot, speed not an issue.
patience = 10

### Training

In [None]:
"""Unchanged and untouch from snighda
Have not made changes. 
"""
# train
for epoch in range(10000): 
    print("-----------------------------------------------------------------")
    print("epoch: {}; bad epochs: {}".format(epoch, num_bad_epochs))
    net.train()
    running_loss = 0.

    #tqdm shows a progress bar. 
    for i, sim_E_vox_batch in enumerate(tqdm(trainloader), 0):
        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        pred_E_vox, pred_adc, pred_sigma, pred_axr = net(sim_E_vox_batch)

        """This is getting messy"""
        #pred_E_vox = torch.tensor(pred_E_vox, dtype=torch.float32, requires_grad=True)
        #sim_E_vox_batch = torch.tensor(sim_E_vox_batch, dtype=torch.float32, requires_grad=True)
        
        pred_E_vox = pred_E_vox.clone().detach().requires_grad_(True).float()
        sim_E_vox_batch = sim_E_vox_batch.clone().detach().requires_grad_(True).float()

        loss = criterion(pred_E_vox, sim_E_vox_batch)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
      
    print("loss: {}".format(running_loss))
    # early stopping
    if running_loss < best:
        print("####################### saving good model #######################")
        final_model = net.state_dict()
        best = running_loss
        num_bad_epochs = 0
    else:
        num_bad_epochs = num_bad_epochs + 1
        if num_bad_epochs == patience:
            print("done, best loss: {}".format(best))
            break
print("done")

net.load_state_dict(final_model)

net.eval()
with torch.no_grad():
    final_pred_E_vox, final_pred_adc, final_pred_sigma, final_pred_axr = net(torch.from_numpy(sim_E_vox.astype(np.float32)))

### Plots

In [None]:
final_pred_E_vox_detached = final_pred_E_vox.detach().numpy()
"""Was having numpy pytorch issues, so this line helps fix it a bit."""

plt.scatter(be, sim_E_vox[0,:], label='simulated')
plt.scatter(be, final_pred_E_vox_detached[0,:], label='predicted')
plt.legend()

# plot scatter plots to analyse correlation of predicted free params against ground truth
plt.figure()


param_sim = [sim_adc, sim_sigma, sim_axr]
param_pred = [final_pred_adc, final_pred_sigma, final_pred_axr]
param_name = ['ADC', 'Sigma', 'AXR']

rvals = []

for i,_ in enumerate(param_sim):
    plt.rcParams['font.size'] = '16'
    plt.scatter(param_sim[i], param_pred[i], s=2, c='navy')
    plt.xlabel(param_name[i] + ' Ground Truth')
    plt.ylabel(param_name[i] + ' Prediction')
    rvals.append(scipy.stats.pearsonr(np.squeeze(param_sim[i]), np.squeeze(param_pred[i])))
    plt.tight_layout
    plt.show()

print(rvals)
