<h1>Training a neural network for direct beam tracking</h1>

This notebook demonstrates how to create a neural network for tracking the direct beam in 4D-STEM data. The network is built with `PyTorch` and trained on an experimentally acquired 4D-STEM scan. The notebook also shows how to pre-process and label the training data.


<p style="text-align:center;"><img src="Figures/part0.svg" width="850"></p>

<br>
<strong>Notebook presentation is part of the 2024 NordTEMhub workshop on (big) data analysis of 4D-STEM. NTNU Trondheim, 11.06.2024.</strong>

<h2>1. Setup</h2>

<h5>Interactive plotting magic</h5>

In [None]:
%matplotlib qt

<h5>Import necessary libraries</h5>

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import hyperspy.api as hs
import pyxem as pxm
from scipy.ndimage import center_of_mass, gaussian_filter

from torch.nn.functional import relu

<h5>Import Python script with useful functions</h5>

In [None]:
import nn_utility

<h5> Make the outputs deterministic </h5>

In [None]:
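# Seed the NumPy and PyTorch random number generators for reproducible results.
# Disable cuDNN autotuning and request deterministic kernels (only relevant on a GPU).
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
np.random.seed(0)
torch.manual_seed(0)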

<h5> Set neural network device to run on</h5>

In [None]:
device = torch.device('cpu')

In [None]:
device

<h2>2. Build the neural network</h2>

<p><img src="Figures/part2.svg" width="700"></p>

In [None]:
class Network(torch.nn.Module):
    def __init__(self):
        super(Network, self).__init__()
        self.pool = torch.nn.MaxPool2d(2, 2)
        self.conv1 = torch.nn.Conv2d(1, 16, 5)
        self.conv2 = torch.nn.Conv2d(16, 32, 5)
        self.fc1 = torch.nn.Linear(32*29*29, 120)
        self.fc2 = torch.nn.Linear(120, 2)
        
    def forward(self, x):
        x = self.conv1(x)
        x = relu(x)
        x = self.pool(x)
        
        x = self.pool(relu(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = relu(self.fc1(x))
        x = self.fc2(x)
        
        return x
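The input size of `fc1`, `32*29*29`, follows from the 128 x 128 pattern size. An unpadded 5 x 5 convolution with stride 1 removes 4 pixels from each side length, and a 2 x 2 max pool with stride 2 halves it (rounding down). The helpers below are a minimal sketch of that arithmetic, assuming square inputs and the layer settings used in `Network`; the function names are hypothetical.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output side length of a Conv2d layer (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size: int, kernel: int = 2, stride: int = 2) -> int:
    """Output side length of MaxPool2d with no padding and floor rounding."""
    return (size - kernel) // stride + 1


def flattened_features(size: int = 128, channels: int = 32) -> int:
    """Number of inputs to fc1 for the conv -> pool -> conv -> pool stack in Network."""
    size = pool_out(conv_out(size, kernel=5))  # conv1 + pool: 128 -> 124 -> 62
    size = pool_out(conv_out(size, kernel=5))  # conv2 + pool: 62 -> 58 -> 29
    return channels * size * size


print(flattened_features())  # 26912, i.e. 32*29*29
```

If the pattern size changes, for example after binning to 64 x 64, `flattened_features(64)` gives the new `fc1` input size: 64 -> 60 -> 30 -> 26 -> 13, i.e. 32*13*13 = 5408.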

<h5>Initialize the model</h5>

In [None]:
model = Network()

<h5>Set model device and inspect initialization</h5>

In [None]:
model = model.to(device)

In [None]:
model.state_dict()

<h2>3. Create training data</h2>

<h5>Load and inspect the 4D-STEM dataset</h5>

In [None]:
s = hs.load('FeAl_stripes.hspy', lazy=False)

In [None]:
s.plot()

<h5>Extract regions for training data</h5>

In [None]:
s1 = s.inav[59:79, 72:92].data.reshape(20*20, 128, 128)
s2 = s.inav[165:205, 55:95].data.reshape(40*40, 128, 128)

In [None]:
training_data = np.concatenate([s1, s2], axis=0)  # append .compute() if s was loaded with lazy=True

In [None]:
training_data.shape

In [None]:
shape = training_data.shape[0]

<h5>Divide the training data into training and validation sets</h5>

In [None]:
rand_indx = np.random.choice(shape, size=int(shape * 0.15), replace=False)

In [None]:
rand_indx

In [None]:
valid_set = hs.signals.Signal2D(training_data[rand_indx])

In [None]:
rem_indx = np.setdiff1d(np.arange(shape), rand_indx)

In [None]:
train_set = hs.signals.Signal2D(training_data[rem_indx])

In [None]:
train_set, valid_set
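The cells above draw 15 % of the pattern indices without replacement for validation and keep the complement for training. With the two extracted regions there are 20*20 + 40*40 = 2000 patterns in total. A minimal NumPy sketch of the same split is shown below; it uses the newer `np.random.default_rng` generator in place of the legacy `np.random.choice`, and the helper name is hypothetical.

```python
import numpy as np


def train_valid_split(n_samples: int, valid_fraction: float = 0.15, seed: int = 0):
    """Return disjoint (train_idx, valid_idx) index arrays covering range(n_samples)."""
    rng = np.random.default_rng(seed)
    n_valid = int(n_samples * valid_fraction)
    # Sample the validation indices without replacement, then take the complement.
    valid_idx = rng.choice(n_samples, size=n_valid, replace=False)
    train_idx = np.setdiff1d(np.arange(n_samples), valid_idx)
    return train_idx, valid_idx


train_idx, valid_idx = train_valid_split(2000)
print(len(train_idx), len(valid_idx))  # 1700 300
```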

<h5>Filter and threshold the training set</h5>

In [None]:
train_gt = train_set.deepcopy()

In [None]:
train_gt.plot()

In [None]:
train_gt.map(gaussian_filter, sigma=1)

In [None]:
train_gt.plot()

In [None]:
# Keep pixels above the 95th percentile, computed once over all patterns in the set
train_gt = train_gt > np.percentile(train_gt.data, 95)

In [None]:
train_gt.plot()

<h5>Create center position labels</h5>

In [None]:
train_gt.map(center_of_mass)

In [None]:
train_gt.data
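The labelling recipe above smooths each pattern with a Gaussian (`sigma=1`), keeps only the pixels above the 95th intensity percentile, and takes the centre of mass of the resulting binary mask as the (y, x) direct-beam position. The sketch below applies the same three steps to a single synthetic pattern with a bright disk at a known position, to check that the label lands on the disk centre. The disk radius, position and noise level are illustrative assumptions, and the helper names are hypothetical.

```python
import numpy as np
from scipy.ndimage import center_of_mass, gaussian_filter


def synthetic_pattern(size=128, center=(70.0, 58.0), radius=18.0, seed=0):
    """Hypothetical test pattern: a bright disk on a weak noisy background."""
    rng = np.random.default_rng(seed)
    yy, xx = np.indices((size, size))
    disk = ((yy - center[0]) ** 2 + (xx - center[1]) ** 2) <= radius ** 2
    return 100.0 * disk + rng.random((size, size))


def label_direct_beam(pattern, sigma=1.0, percentile=95):
    """Gaussian smoothing -> percentile threshold -> centre of mass (row, col)."""
    smoothed = gaussian_filter(pattern, sigma=sigma)
    mask = smoothed > np.percentile(smoothed, percentile)
    return center_of_mass(mask)


cy, cx = label_direct_beam(synthetic_pattern())
print(cy, cx)  # approximately 70 and 58
```

For a single 128 x 128 pattern the 95th percentile keeps about 819 pixels, so the recipe implicitly assumes that the smoothed direct beam covers at least that many pixels. For a much smaller disk, background pixels would enter the mask and pull the label towards the pattern centre. The notebook computes one percentile over the whole set rather than per pattern, so this is only a rough guide there.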

<h5>Inspect the training set with labels</h5>

In [None]:
indx = np.random.choice(np.arange(train_set.data.shape[0]), size=4, replace=False)
indx

In [None]:
nn_utility.plot_patterns(indx, train_set, train_gt)

<h5>Repeat for validation set</h5>

In [None]:
valid_gt = valid_set.deepcopy()
valid_gt.map(gaussian_filter, sigma=1)
# Keep pixels above the 95th percentile, computed once over all patterns in the set
valid_gt = valid_gt > np.percentile(valid_gt.data, 95)
valid_gt.map(center_of_mass)

<h5>Create a torch tensor dataset</h5>

In [None]:
data_train = nn_utility.DatasetLoader(train_set, train_gt)
data_valid = nn_utility.DatasetLoader(valid_set, valid_gt)

In [None]:
data_train[0][0].shape

In [None]:
train_set.inav[0].data.shape
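`nn_utility.DatasetLoader` is defined in the accompanying script, and its implementation is not shown here. For orientation, a map-style dataset only needs `__len__` and `__getitem__`. The hypothetical `PatternDataset` below is one minimal way to wrap a pattern stack and its (y, x) labels. It assumes each item is a `float32` array of shape `(1, 128, 128)` (one input channel, as expected by `conv1`) plus a `float32` label of shape `(2,)`, and that only the pattern is returned when no labels are given, mirroring how `DatasetLoader(valid_set)` is used for inference later. Whether the original class normalises intensities is unknown. To stay dependency-free the sketch does not subclass `torch.utils.data.Dataset`, which would be the more idiomatic choice in a PyTorch project.

```python
import numpy as np


class PatternDataset:
    """Hypothetical stand-in for nn_utility.DatasetLoader (map-style dataset).

    patterns: array of shape (N, H, W); labels: optional array of shape (N, 2)
    holding (y, x) centre positions.
    """

    def __init__(self, patterns, labels=None):
        # Add a channel axis so each item is (1, H, W), as Conv2d(1, ...) expects.
        self.patterns = np.asarray(patterns, dtype=np.float32)[:, np.newaxis]
        self.labels = None if labels is None else np.asarray(labels, dtype=np.float32)

    def __len__(self):
        return len(self.patterns)

    def __getitem__(self, idx):
        if self.labels is None:
            return self.patterns[idx]  # inference: pattern only
        return self.patterns[idx], self.labels[idx]


ds = PatternDataset(np.zeros((5, 128, 128)), np.zeros((5, 2)))
x, y = ds[0]
print(len(ds), x.shape, y.shape)  # 5 (1, 128, 128) (2,)
```

With the HyperSpy signals used here, one would pass `train_set.data` and `train_gt.data`. When wrapped in `torch.utils.data.DataLoader`, the default collate function stacks these NumPy items into tensors of shape `(batch_size, 1, 128, 128)` and `(batch_size, 2)`.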

<h5>Define batch size</h5>

In [None]:
batch_size = 10

<h5>Create dataloader iterable</h5>

In [None]:
data_train_iter = torch.utils.data.DataLoader(data_train, batch_size=batch_size, shuffle=True)
data_valid_iter = torch.utils.data.DataLoader(data_valid, batch_size=batch_size, shuffle=True)

In [None]:
next(iter(data_train_iter))

<h2>4. Training</h2>

<h5>Create a metric</h5>
<p><img src="Figures/part4.svg" width="450"></p>

In [None]:
r2 = lambda pred, gt: 1 - torch.sum((gt - pred) ** 2) / torch.sum((gt - torch.mean(gt)) ** 2)
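The lambda implements the coefficient of determination, R² = 1 - Σ(y - ŷ)² / Σ(y - ȳ)², pooled over all elements of the batch (both the y and the x coordinate). R² is 1 for perfect predictions and 0 for a model that always predicts the mean label. Because the two coordinates are pooled, the total sum of squares also contains the spread between the mean y and mean x labels. The NumPy version below uses a small worked example with made-up numbers.

```python
import numpy as np


def r2_score(pred, gt):
    """Coefficient of determination pooled over all elements, as in the r2 lambda."""
    pred = np.asarray(pred, dtype=float)
    gt = np.asarray(gt, dtype=float)
    ss_res = np.sum((gt - pred) ** 2)         # residual sum of squares
    ss_tot = np.sum((gt - np.mean(gt)) ** 2)  # total sum of squares about the mean
    return 1.0 - ss_res / ss_tot


gt = np.array([[60.0, 62.0], [64.0, 66.0]])    # (y, x) labels for two patterns
pred = np.array([[61.0, 62.0], [64.0, 65.0]])
print(r2_score(pred, gt))  # 1 - 2/20 = 0.9
```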

<h5>Define the remaining hyperparameters</h5>

In [None]:
num_epochs = 5000
learning_rate = 1e-4

<h5>Define the loss function and optimizer</h5>

In [None]:
loss_function = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=5e-4)

<h5>Initiate training</h5>

In [None]:
train_loss, valid_loss, train_met, valid_met  = nn_utility.train(model, optimizer, loss_function, r2,
                                                                 batch_size, data_train_iter, data_valid_iter,
                                                                 device, num_epochs, early_stop=True)
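`nn_utility.train` is not shown in this notebook, so how its `early_stop=True` option decides when to stop is not documented here. A common scheme, sketched below as an assumption rather than the script's actual logic, is patience-based: stop once the validation loss has not improved by more than `min_delta` for `patience` consecutive epochs. The class name and both parameters are hypothetical.

```python
class EarlyStopping:
    """Hypothetical patience-based early stopping on the validation loss."""

    def __init__(self, patience=20, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, valid_loss):
        """Record one epoch's validation loss; return True when training should stop."""
        if valid_loss < self.best - self.min_delta:
            self.best = valid_loss  # improvement: remember it and reset the counter
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


stopper = EarlyStopping(patience=3)
losses = [0.50, 0.30, 0.20, 0.21, 0.22, 0.20, 0.25]
stopped_at = next(i for i, loss in enumerate(losses) if stopper.step(loss))
print(stopped_at)  # 5: three epochs in a row without beating 0.20
```

Inside an epoch loop this would be used as `if stopper.step(epoch_valid_loss): break`. A usual refinement is to keep a copy of the weights from the best epoch, so that the saved model is the best one rather than the last one.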

<h5>Save the trained model</h5>

In [None]:
torch.save({
    'model_state_dict': model.state_dict(),
    'training_loss': train_loss,
    'validation_loss': valid_loss,
    'learning_rate': learning_rate,
}, 'model.pth')

<h5>Inspect the loss during training</h5>

In [None]:
nn_utility.plot_graphs(train_loss, valid_loss, ylim=(0, .2))

In [None]:
nn_utility.plot_graphs(train_met, valid_met, ylim=(.98, 1))

In [None]:
valid_met[-1]

<h2>5. Inference</h2>

<h3>Inference on the validation dataset</h3>

In [None]:
data_test = nn_utility.DatasetLoader(valid_set)
data_test_iter = torch.utils.data.DataLoader(data_test, batch_size=batch_size)

<h5>Create function to do predictions on batches</h5>

In [None]:
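def predict(batch, model):
    """Return the model's (y, x) predictions for one batch as a NumPy array."""
    # Move the batch to whichever device the model's weights live on
    batch = batch.to(next(model.parameters()).device)
    with torch.no_grad():
        pred = model(batch)
    return pred.cpu().numpy()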

<h5>Loop over all batches</h5>

In [None]:
predictions = list()
for x in data_test_iter:
    predictions.append(predict(x, model))

In [None]:
# Stack the per-batch (batch_size, 2) arrays; this also handles a smaller last batch
pred = hs.signals.Signal1D(np.concatenate(predictions, axis=0))

<h5>Plot and compare</h5>

In [None]:
indx = np.random.choice(np.arange(valid_set.data.shape[0]), size=4, replace=False)

In [None]:
nn_utility.plot_patterns(indx, valid_set, valid_gt, pred)

<h3>Inference on an unlabelled region of the dataset</h3>

<h5>Extract region of interest and make torch dataloader iterable</h5>

In [None]:
s.plot()

In [None]:
s_test = s.inav[110:160, 10:150]

In [None]:
# Only needed if s was loaded lazily (lazy=True):
# s_test.compute()

In [None]:
s_test

In [None]:
s_test.plot()

In [None]:
# Flatten the (y, x) scan axes into one stack of patterns, in row-major order
s_test = hs.signals.Signal2D(s_test.data.reshape(-1, *s_test.data.shape[2:]))

In [None]:
infer_set = nn_utility.DatasetLoader(s_test)
infer_data_iter = torch.utils.data.DataLoader(infer_set, batch_size=batch_size, shuffle=False)
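The scan region is flattened into a stack of 50 x 140 = 7000 patterns, and the predictions are later reshaped back onto the (140, 50) scan grid. This only works because the flattening is row-major (C order) and the inference loader uses `shuffle=False`, so prediction *k* belongs to scan position `(k // 50, k % 50)`. The NumPy check below illustrates that round trip on a toy grid, with the position of the brightest pixel standing in for the network's prediction (an illustrative assumption, not the model).

```python
import numpy as np

ny, nx, h, w = 4, 3, 8, 8                      # toy scan grid and pattern size
rng = np.random.default_rng(0)
scan = rng.random((ny, nx, h, w))              # (y, x, ky, kx), as in s_test.data

stack = scan.reshape(ny * nx, h, w)            # row-major flattening of the scan axes
# Stand-in "prediction": (row, col) of the brightest pixel of each pattern, shape (N, 2)
fake_pred = np.array([np.unravel_index(p.argmax(), p.shape) for p in stack])

grid = fake_pred.reshape(ny, nx, 2)            # back onto the scan grid
# Every grid entry matches the pattern at the same scan position.
# With shuffle=True the loader would permute the stack and this would break.
for iy in range(ny):
    for ix in range(nx):
        assert tuple(grid[iy, ix]) == np.unravel_index(scan[iy, ix].argmax(), (h, w))
```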

<h5>Load the saved model weights</h5>

In [None]:
checkpoint = torch.load('model.pth')

In [None]:
checkpoint

In [None]:
model = Network()

In [None]:
model.state_dict()

In [None]:
model.load_state_dict(checkpoint['model_state_dict'])

In [None]:
model.state_dict()

<h5>Set model to inference mode</h5>

In [None]:
model.eval()

<h5>Perform inference</h5>

In [None]:
predictions = list()
for x in infer_data_iter:
    predictions.append(predict(x, model))

In [None]:
# Stack all batches, then restore the (140, 50) scan grid; last axis holds (y, x)
shifts_yx = np.concatenate(predictions, axis=0).reshape((140, 50, 2))

In [None]:
plt.figure()
plt.imshow(shifts_yx[..., 0])
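The network outputs absolute (y, x) positions in detector pixels, so the map above shows the beam's y position across the scan. To express these as shifts of the direct beam, one option is to subtract a reference position, for example the nominal detector centre or the median position over the scan. The sketch below assumes a 128 x 128 detector whose centre is (63.5, 63.5) in the pixel-index convention used by `center_of_mass`. It uses a synthetic position map so that it runs on its own, and the helper name is hypothetical.

```python
import numpy as np


def positions_to_shifts(positions, reference="median"):
    """Convert (..., 2) beam positions (y, x) to shifts relative to a reference.

    reference: "median" (median position over the scan) or an explicit (y, x) pair.
    """
    positions = np.asarray(positions, dtype=float)
    if isinstance(reference, str):
        if reference != "median":
            raise ValueError("reference must be 'median' or a (y, x) pair")
        ref = np.median(positions.reshape(-1, 2), axis=0)
    else:
        ref = np.asarray(reference, dtype=float)
    return positions - ref  # broadcasts over the scan axes


rng = np.random.default_rng(0)
positions = 63.5 + rng.normal(0.0, 0.5, size=(140, 50, 2))  # synthetic stand-in
shifts = positions_to_shifts(positions, reference=(63.5, 63.5))
magnitude = np.hypot(shifts[..., 0], shifts[..., 1])        # |shift| in pixels
```

Passing `shifts_yx` instead of the synthetic `positions` applies the same conversion to the notebook's result, and `plt.imshow(magnitude)` then maps the shift magnitude across the scan.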

<h2>6. Hidden layer outputs (optional)</h2>

<h5>Extract a small region of the dataset</h5>

In [None]:
dp = s.inav[110, 10]

In [None]:
# Only needed if s was loaded lazily (lazy=True):
# dp.compute()

In [None]:
dp.data = np.expand_dims(dp.data, axis=0)

In [None]:
dp.data.shape

In [None]:
dp_set = nn_utility.DatasetLoader(dp)
dp_iter = torch.utils.data.DataLoader(dp_set, batch_size=1, shuffle=False)

<h5>Redefine the model</h5>

In [None]:
class Network(torch.nn.Module):
    def __init__(self):
        super(Network, self).__init__()
        # All layers are kept so that the saved state_dict loads without missing keys
        self.pool = torch.nn.MaxPool2d(2, 2)
        self.conv1 = torch.nn.Conv2d(1, 16, 5)
        self.conv2 = torch.nn.Conv2d(16, 32, 5)
        self.fc1 = torch.nn.Linear(32*29*29, 120)
        self.fc2 = torch.nn.Linear(120, 2)
        
    def forward(self, x):
        # Return the output of the first hidden layer (conv1 -> ReLU -> max pool)
        # instead of the (y, x) prediction; conv2, fc1 and fc2 are not used here.
        x = self.conv1(x)
        x = relu(x)
        x = self.pool(x)
        return x

<h5>Load model weights</h5>

In [None]:
model = Network()

In [None]:
checkpoint = torch.load('model.pth')
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
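The modified `forward` returns right after the first convolution, ReLU and pooling step, so `model(x)` yields the 16 pooled feature maps of `conv1` instead of the (y, x) prediction. Because the layers are `torch.nn.Module` attributes, the same intermediate output can also be obtained from the unmodified network by calling them directly, which avoids redefining the class. The sketch below uses random weights and a random input in place of the trained model and `dp`.

```python
import torch
from torch.nn.functional import relu


class Network(torch.nn.Module):
    # Same layers as the network trained above (copied so the sketch is self-contained).
    def __init__(self):
        super().__init__()
        self.pool = torch.nn.MaxPool2d(2, 2)
        self.conv1 = torch.nn.Conv2d(1, 16, 5)
        self.conv2 = torch.nn.Conv2d(16, 32, 5)
        self.fc1 = torch.nn.Linear(32 * 29 * 29, 120)
        self.fc2 = torch.nn.Linear(120, 2)

    def forward(self, x):
        x = self.pool(relu(self.conv1(x)))
        x = self.pool(relu(self.conv2(x)))
        x = torch.flatten(x, 1)
        return self.fc2(relu(self.fc1(x)))


model = Network().eval()
x = torch.rand(1, 1, 128, 128)  # stand-in for one diffraction pattern

with torch.no_grad():
    feature_maps = model.pool(relu(model.conv1(x)))  # first hidden layer output
    prediction = model(x)                            # full forward pass still works

print(feature_maps.shape, prediction.shape)  # torch.Size([1, 16, 62, 62]) torch.Size([1, 2])
```

With the trained weights loaded from `checkpoint['model_state_dict']` into this unmodified class, passing the batch from `dp_iter` through `model.pool(relu(model.conv1(...)))` gives the same maps as the redefined `forward`.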

<h5>Perform inference</h5>

In [None]:
with torch.no_grad():
    pred = model(next(iter(dp_iter)))

In [None]:
pred = pred.squeeze(0).numpy()  # (16, 62, 62): one pooled feature map per conv1 filter
dp.data = np.squeeze(dp.data)

In [None]:
pred.shape

<h5>Extract feature maps and feature map weights (filters/kernels)</h5>

In [None]:
feature_map_indx = [0, 2, 8, 11, 14]

In [None]:
maps = pred[feature_map_indx]
maps = np.squeeze(maps)

In [None]:
conv_layer_filters = model.state_dict()['conv1.weight']
conv_layer_filters

In [None]:
conv_layer_filters.shape

In [None]:
filters = conv_layer_filters[feature_map_indx]
filters = torch.squeeze(filters)
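Each feature map is produced by sliding the corresponding 5 x 5 filter over the input. PyTorch's `Conv2d` computes a cross-correlation (the kernel is not flipped) and adds the filter's bias, which is not shown in the filter plots; the network then applies ReLU and 2 x 2 max pooling. The sketch below checks this for one output pixel of a randomly initialised layer standing in for `model.conv1`; the chosen indices are arbitrary.

```python
import torch
from torch.nn.functional import relu

torch.manual_seed(0)
conv1 = torch.nn.Conv2d(1, 16, 5)  # stand-in for model.conv1
x = torch.rand(1, 1, 128, 128)     # stand-in for one pattern

with torch.no_grad():
    out = relu(conv1(x))                 # (1, 16, 124, 124), before pooling
    k, i, j = 2, 40, 70                  # feature map index and output pixel
    patch = x[0, 0, i:i + 5, j:j + 5]    # 5x5 input window under the filter
    manual = relu((conv1.weight[k, 0] * patch).sum() + conv1.bias[k])

print(torch.allclose(out[0, k, i, j], manual, atol=1e-5))  # True
```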

<h5>Plot the hidden layer feature maps to see how the convolutional layer extracts features for prediction</h5>

In [None]:
fig, axes = plt.subplots(2, 6, figsize=(14.5, 5), constrained_layout=True)
for i in range(len(feature_map_indx)):
    
    axes[0,i+1].imshow(maps[i])
    axes[1,i+1].imshow(filters[i])

    axes[0,i+1].set_title('Feature map %i'%feature_map_indx[i], fontsize=16)
    axes[0,i+1].set_xticks([]); axes[0,i+1].set_yticks([])
    axes[1,i+1].set_xticks([]); axes[1,i+1].set_yticks([])

axes[1,1].set_ylabel('Filter', fontsize=16)

axes[0,0].imshow(dp.data); axes[0,0].set_title('Input', fontsize=16)
axes[0,0].set_axis_off(); axes[1,0].set_axis_off()