# DNN Practical: Ag Detection by Muon Spectroscopy

In this notebook, we attempt to solve a real problem in physics using a fully connected DNN.

We have a set of spectra from muon spectroscopy experiments, and we would like to use them to detect whether a certain element is present in a sample. Here we train a neural network to detect the presence of Ag. Along the way, we will encounter and overcome a pitfall in deep learning known as **class imbalance**. We will also explore **early stopping** and saving checkpoints from the best performing model.

## About the data

The data in this example is generated from simulated muon spectroscopy experiments. First, the spectrum of each individual element was generated by simulating that element's spectral emission lines. Then, for the mixed compounds, the elemental spectra were combined in proportion to how much of each element is present in the compound.
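
The mixing step amounts to a weighted sum of elemental spectra. The toy sketch below illustrates that idea only; the Gaussian line shapes, the line positions and the `mix` helper are hypothetical choices for illustration and are not the simulation that actually produced this dataset.

```python
import numpy as np

channels = np.arange(1000)  # assumed 1000-channel grid, matching the dataset's spectrum length

def toy_element_spectrum(line_centres, width=5.0):
    """Hypothetical elemental spectrum: a sum of Gaussian emission lines."""
    spectrum = np.zeros(channels.shape, dtype=float)
    for centre in line_centres:
        spectrum += np.exp(-0.5 * ((channels - centre) / width) ** 2)
    return spectrum

# made-up line positions, for illustration only
elemental = {
    'Ag': toy_element_spectrum([220, 410]),
    'Si': toy_element_spectrum([180]),
}

def mix(fractions):
    """Combine elemental spectra weighted by each element's fraction in the compound."""
    total = sum(fractions.values())
    return sum((f / total) * elemental[el] for el, f in fractions.items())

ag_si = mix({'Ag': 0.3, 'Si': 0.7})  # a toy Ag-Si binary
```

Because the mixture is linear, a spectrum with a small Ag fraction carries only a weak copy of the Ag lines, which is part of what makes detection hard.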

In [None]:
!wget https://zenodo.org/records/14230642/files/ag-muon-data-tight.pkl

In [11]:
# torch
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision.datasets as datasets

# helpers
import time
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
plt.style.use('ggplot')

# fix the random seed so that the random sampling below is reproducible
np.random.seed(0)

from tqdm.notebook import trange, tqdm

---

# The dataset

### Read raw data

The raw data, which include the constituent elements and the muon spectra of the samples, are stored in the pickle file `ag-muon-data-tight.pkl` downloaded above. We load this file into a `pandas` dataframe and take a quick look.

In [None]:
# read data
df = pd.read_pickle('ag-muon-data-tight.pkl')
# print dimensions
print('Number of samples in the dataset: %d' % len(df['Spectra']))
print('Length of spectra for each sample: %d' % len(df['Spectra'][0]))

# print the first few data
df.head(n=5)

In the above table, the `Elements` and the `Spectra` columns show respectively the elements and the spectra of the samples. There are 138,613 samples in the dataset, and each spectrum is a series of 1000 positive reals.

To get a feel for the complexity of picking out signals with Ag in multinary samples, we can plot some random spectra for three representative cases:

* no Ag
* pure Ag
* Ag-Si binary

Note that we are plotting only part of each spectrum (channels 150 to 700). Change `[150:700]` to `[:]` to show the full spectra.

In [None]:
# conditions to select data
conditions = [
# no Ag
('no Ag', np.where(['Ag' not in elements for elements in df['Elements']])[0]),
# pure Ag
('pure Ag', np.where([['Ag'] == elements for elements in df['Elements']])[0]),
# Ag-Si
('Ag-Si binary', np.where([['Ag', 'Si'] == elements for elements in df['Elements']])[0])
]

# plot
ncond = len(conditions)
nplot = 4 # number of plots per condition
fig, axs = plt.subplots(nplot, ncond, dpi=200, figsize=(ncond * 5, nplot * 2), sharex=True, sharey=True)
plt.subplots_adjust(wspace=.1, hspace=.2)
for icond, cond in enumerate(conditions):
    for iplot, idata in enumerate(np.random.choice(cond[1], nplot)):
        axs[iplot, icond].plot(df['Spectra'][idata][150:700], c='C%d' % icond)
        axs[iplot, icond].set_xlabel('Sample %d: %s' % (idata, cond[0]), c='C%d' % icond)
        axs[iplot, icond].set_ylim(0, 100)

### Extract training data

The input to our network will be the `Spectra` column, which we convert to a numpy array via `to_list()`. The output will be a binary label: 0 if the sample contains no Ag and 1 otherwise; a list comprehension over the `Elements` column is enough to build it. It is also important to normalise each spectrum to the range [0, 1], which we do by dividing it by its maximum.
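
On a toy input, the normalisation and the labelling look like this. The three-channel spectra and element lists below are invented for illustration; the sketch assumes every spectrum has a positive maximum (an all-zero spectrum would divide by zero).

```python
import numpy as np

# two invented 3-channel "spectra" and their element lists
spectra = [[2.0, 4.0, 8.0],
           [5.0, 1.0, 0.0]]
elements = [['Ag', 'Si'], ['Fe']]

x = np.array(spectra, dtype=float)
# divide each row by its own maximum; [:, np.newaxis] reshapes (N,) to (N, 1) so it broadcasts over channels
x /= np.max(x, axis=1)[:, np.newaxis]
# rows become [0.25, 0.5, 1.0] and [1.0, 0.2, 0.0]

# binary label: 1 if 'Ag' is among the sample's elements, else 0
y = np.array(['Ag' in els for els in elements]).astype(int)
# y is [1, 0]
```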

In [None]:
###### input ######
limit = None  # use every sample; set an int (e.g. 10000) to work on a subset
# convert the 'Spectra' column to a float numpy array
train_x = np.array(df['Spectra'].to_list(), dtype=float)[:limit]
# normalise each spectrum to [0, 1] by dividing by its maximum
train_x /= np.max(train_x, axis=1)[:, np.newaxis]

###### output ######
# binary label: 1 if Ag is in 'Elements', else 0
train_y = np.array(['Ag' in elements for elements in df['Elements']]).astype(int)[:limit]

# print data shapes
print("Shape of input: %s" % str(train_x.shape))
print("Shape of output: %s" % str(train_y.shape))

##### to torch ######
x_in = torch.from_numpy(train_x).type(torch.float)
y_true = torch.from_numpy(train_y).type(torch.float).unsqueeze(1)

### split into train/val ######

train_lim = int(x_in.shape[0] * 0.8)
x_in_train = x_in[:train_lim]
y_true_train = y_true[:train_lim]

x_in_val = x_in[train_lim:]
y_true_val = y_true[train_lim:]

train_data = [(x_in_train[i], y_true_train[i]) for i in range(x_in_train.shape[0])]
val_data = [(x_in_val[i], y_true_val[i]) for i in range(x_in_val.shape[0])]

### Set up a dataloader, like in the lecture notebook

Use a batch size of 512.

**Suggested Answer**

<details> <summary>Show / Hide</summary>
<p>
    
```python
BATCH_SIZE = 512

train_iterator = data.DataLoader(train_data,
                                 shuffle=True,
                                 batch_size=BATCH_SIZE)

valid_iterator = data.DataLoader(val_data,
                                 batch_size=BATCH_SIZE)
```
    
</p>
</details>
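
Building Python lists of `(x, y)` tuples works, but `torch.utils.data.TensorDataset` can wrap the tensors directly. The sketch below uses random stand-in tensors whose shapes mimic `x_in_train` and `y_true_train`; the values are not real spectra.

```python
import torch
import torch.utils.data as data

# random stand-ins with the same layout as x_in_train / y_true_train
x_demo = torch.rand(1200, 1000)
y_demo = torch.randint(0, 2, (1200, 1)).float()

train_set = data.TensorDataset(x_demo, y_demo)
loader = data.DataLoader(train_set, batch_size=512, shuffle=True)

batch_x, batch_y = next(iter(loader))
print(batch_x.shape, batch_y.shape)  # torch.Size([512, 1000]) torch.Size([512, 1])
print(len(loader))                   # 3 batches: 512 + 512 + 176 samples
```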

# Ag-detection by DNN

## 1. Try out a network


### Build the model

Based on what we have learnt in [DNN_basics.ipynb](DNN_basics.ipynb), design a simple fully connected neural network (stacked `nn.Linear` layers) to detect Ag in the spectra. In general, choosing the number of hidden layers and the number of neurons in each layer is not straightforward and usually involves some trial and error. Here the output size is only 1, so a small hidden layer (for example 16 neurons) is a reasonable starting point; we can then scale it up to 64 and compare.

Next, set up the optimiser and the loss function. We can keep using Adam (`optim.Adam`) as the optimiser and track accuracy as our metric. For the loss, since the sigmoid output lies between 0 and 1 and the labels are binary, binary cross-entropy (`nn.BCELoss`) is the natural choice.
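
Binary cross-entropy compares a predicted probability $p$ with a label $y \in \{0, 1\}$ via $-[y \log p + (1 - y)\log(1 - p)]$, averaged over the batch. The check below, on a few made-up values, confirms that `nn.BCELoss` matches this formula. If a model outputs raw logits instead of sigmoid probabilities, `nn.BCEWithLogitsLoss` combines the sigmoid and the loss in one numerically more stable step.

```python
import torch
import torch.nn as nn

p = torch.tensor([[0.9], [0.2], [0.6]])  # made-up sigmoid outputs
y = torch.tensor([[1.0], [0.0], [0.0]])  # binary labels

# the formula written out by hand, averaged over the batch
manual = -(y * torch.log(p) + (1 - y) * torch.log(1 - p)).mean()
builtin = nn.BCELoss()(p, y)

# BCEWithLogitsLoss expects the pre-sigmoid value, so feed it logit(p)
logits = torch.log(p / (1 - p))
with_logits = nn.BCEWithLogitsLoss()(logits, y)
```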


**Suggested Answer**

<details> <summary>Show / Hide</summary>
<p>
    
```python
class MLP(nn.Module):
    """ MLP model for binary Ag detection """

    def __init__(self, input_size, hidden_size, num_hidden_layers=1, activation=torch.relu):
        super().__init__()
        self.input_layer = nn.Linear(input_size, hidden_size)
        self.hidden_layers = nn.ModuleList(
            [nn.Linear(hidden_size, hidden_size) for _ in range(num_hidden_layers)]
        )
        self.output_layer = nn.Linear(hidden_size, 1)  # a single output: probability of Ag
        self.act = activation
        # He (Kaiming) uniform initialisation, which suits ReLU activations
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_uniform_(m.weight, nonlinearity='relu')
                nn.init.zeros_(m.bias)

    def forward(self, x):
        x = self.act(self.input_layer(x))
        for layer in self.hidden_layers:
            x = self.act(layer(x))
        x = self.output_layer(x)
        return torch.sigmoid(x)  # squash to (0, 1) for nn.BCELoss

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# try hidden_size=16 first, then scale up to 64
model = MLP(input_size=x_in.shape[1], hidden_size=64).to(device)
optimizer = optim.Adam(model.parameters())
criterion = nn.BCELoss()
```
    
</p>
</details>


### Helper functions

We define a few functions to help us calculate the accuracy of predictions and record the time of an epoch during training.

In [22]:
def calculate_accuracy(y_pred, y, threshold=0.5):
    """Fraction of predictions that match the binary labels y."""
    out = (y_pred > threshold).float()  # probabilities -> hard 0/1 predictions
    acc = torch.sum(out == y)
    acc = acc / y.shape[0]
    return acc

def epoch_time(start_time, end_time):
    """Split an elapsed time in seconds into whole minutes and seconds."""
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

### Train the model

Define a training and evaluation loop as in the lecture, then combine these and train this model for 10 epochs.


**Suggested Answer**

<details> <summary>Show / Hide</summary>
<p>
    
```python
def train(model, iterator, optimizer, criterion, device):

    epoch_loss = 0
    epoch_acc = 0

    model.train()

    for (x, y) in tqdm(iterator, desc="Training", leave=False):

        x = x.to(device)                   # Move the batch to the compute device
        y = y.to(device)

        optimizer.zero_grad()              # Clear gradients left over from the previous step
        y_pred = model(x)                  # Forward pass: predicted probabilities
        loss = criterion(y_pred, y)        # Calculate the loss
        acc = calculate_accuracy(y_pred, y)
        loss.backward()                    # Backprop: compute gradients of the loss
        optimizer.step()                   # Update the weights using those gradients

        epoch_loss += loss.item()
        epoch_acc += acc.item()

    return epoch_loss / len(iterator), epoch_acc / len(iterator)

def evaluate(model, iterator, criterion, device):

    epoch_loss = 0
    epoch_acc = 0

    model.eval()

    with torch.no_grad():                  # No gradients needed for evaluation

        for (x, y) in tqdm(iterator, desc="Evaluating", leave=False):

            x = x.to(device)
            y = y.to(device)

            y_pred = model(x)
            loss = criterion(y_pred, y)
            acc = calculate_accuracy(y_pred, y)

            epoch_loss += loss.item()
            epoch_acc += acc.item()

    return epoch_loss / len(iterator), epoch_acc / len(iterator)

EPOCHS = 10

best_valid_loss = float('inf')
history = []

for epoch in trange(EPOCHS):

    start_time = time.monotonic()

    train_loss, train_acc = train(model, train_iterator, optimizer, criterion, device)
    valid_loss, valid_acc = evaluate(model, valid_iterator, criterion, device)

    # checkpoint: keep the weights with the lowest validation loss so far
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'mlp-model.pt')

    end_time = time.monotonic()

    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    history.append({'epoch': epoch, 'epoch_time': end_time - start_time,
                    'train_loss': train_loss, 'valid_loss': valid_loss,
                    'train_acc': train_acc, 'valid_acc': valid_acc})

    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc*100:.2f}%')
```
    
</p>
</details>
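
Because the loop writes `mlp-model.pt` only when the validation loss improves, the file holds the best weights seen rather than the last ones. Restoring them is a matter of `load_state_dict`. The self-contained sketch below uses a tiny stand-in model and a temporary file path (both assumptions, to keep it independent of the notebook).

```python
import os
import tempfile
import torch
import torch.nn as nn

net = nn.Linear(4, 1)                                 # stand-in for the MLP
path = os.path.join(tempfile.mkdtemp(), 'best.pt')    # stand-in for 'mlp-model.pt'

torch.save(net.state_dict(), path)                    # checkpoint the "best" weights
best_weight = net.weight.detach().clone()

with torch.no_grad():
    net.weight.add_(1.0)                              # simulate further (worse) training

net.load_state_dict(torch.load(path))                 # roll back to the checkpoint
net.eval()                                            # switch to inference mode before predicting
```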

### Plot training history

Reuse the code from the lecture to plot the training stats. They will look bizarre at this stage, as explained in the forthcoming section.

**Suggested Answer**

<details> <summary>Show / Hide</summary>
<p>
    
```python
epochs = [x["epoch"] for x in history]
train_acc = [x["train_acc"] for x in history]
valid_acc = [x["valid_acc"] for x in history]

fig, ax = plt.subplots()
ax.plot(epochs, train_acc, label="train")
ax.plot(epochs, valid_acc, label="valid")
ax.set(xlabel="Epoch", ylabel="Accuracy")
ax.legend()
```
    
</p>
</details>

## 2. Class imbalance

In the above history plot, notice how the accuracy of the model converges to a high value very quickly (>90% at the end of the first epoch). Such an odd history suggests that something could be wrong with our dataset.

### Data distribution

Let us inspect the distribution of the data using `plt.hist(train_y)`, paying special attention to the validation part (the final 20%).

In [None]:
# plot distribution of data
plt.figure(dpi=100)
plt.hist(train_y, label='Whole dataset')
plt.hist(train_y[-len(train_y)//5:], label='Validation subset')
plt.xticks([0, 1], ['0: no Ag', '1: with Ag'])
plt.xlabel('label')
plt.ylabel('number of data')
plt.legend()
plt.show()

The histograms show that our dataset is dominated by samples labelled 0 or "no Ag", which account for over 95% of the data. Thus, if the model simply learns to *guess* "no Ag" in every sample, it can achieve 95% accuracy without learning anything meaningful. This problem is known as **class imbalance**.

To avoid this, we must deal with the imbalance. There are a number of strategies we can take:

* Upsample the minority class;
* Downsample the majority class;
* Change the performance metric.
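
The third strategy can be made concrete with a toy calculation. With invented labels that are 95% zeros, a "model" that always answers "no Ag" scores 95% accuracy, while balanced accuracy (the mean of the per-class recalls, one common imbalance-aware metric, chosen here for illustration) exposes it at 50%.

```python
import numpy as np

y_true = np.array([0] * 95 + [1] * 5)   # invented labels: 95% "no Ag"
y_pred = np.zeros_like(y_true)          # lazy model: always predicts "no Ag"

accuracy = np.mean(y_pred == y_true)

# recall of each class = fraction of that class predicted correctly
recall_no_ag = np.mean(y_pred[y_true == 0] == 0)
recall_ag = np.mean(y_pred[y_true == 1] == 1)
balanced_accuracy = (recall_no_ag + recall_ag) / 2

print(accuracy, balanced_accuracy)  # 0.95 0.5
```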

The best available option for our problem is to downsample the majority class, which can be easily achieved with `numpy`:

In [26]:
# find original indices of 0 ('no Ag') and 1 ('with Ag')
id_no_Ag = np.where(train_y == 0)[0]
id_with_Ag = np.where(train_y == 1)[0]

# downsample 'no Ag' to the number of 'with Ag' by np.random.choice
# (replace=False so that no 'no Ag' spectrum is picked twice)
id_no_Ag_downsample = np.random.choice(id_no_Ag, len(id_with_Ag), replace=False)

# concatenate 'with Ag' and downsampled 'no Ag'
id_downsample = np.concatenate((id_with_Ag, id_no_Ag_downsample))

# shuffle the indices because they are ordered after concatenation
np.random.shuffle(id_downsample)

# finally get the balanced data
train_x_balanced = train_x[id_downsample]
train_y_balanced = train_y[id_downsample]
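
For comparison, the first strategy in the list, upsampling the minority class, draws minority indices *with* replacement until both classes have the same size. The sketch below runs on invented labels; note that it duplicates minority samples, which can encourage overfitting to those repeated spectra.

```python
import numpy as np

rng = np.random.default_rng(0)
y = np.array([0] * 90 + [1] * 10)   # invented imbalanced labels

id_major = np.where(y == 0)[0]
id_minor = np.where(y == 1)[0]

# draw minority indices with replacement until they match the majority count
id_minor_up = rng.choice(id_minor, size=len(id_major), replace=True)

id_up = np.concatenate((id_major, id_minor_up))
rng.shuffle(id_up)                   # mix the two classes

y_up = y[id_up]
print(np.bincount(y_up))  # [90 90]
```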

Re-examine the histograms of the dataset after downsampling the majority class:

In [None]:
# plot distribution of downsampled data
plt.figure(dpi=100)
plt.hist(train_y_balanced, label='Whole dataset')
plt.hist(train_y_balanced[-len(train_y_balanced)//5:], label='Validation subset')
plt.xticks([0, 1], ['0: no Ag', '1: with Ag'])
plt.xlabel('label')
plt.ylabel('number of data')
plt.legend()
plt.show()

### Re-train the model

Now we can re-train the model on the balanced dataset. Simply replace `train_x` and `train_y` with `train_x_balanced` and `train_y_balanced` and repeat all the steps in [1. Try out a network](#1.-Try-out-a-network), starting from a freshly initialised model. More epochs (say 100) can be afforded because we now have much less data.


**Suggested Answer**

<details> <summary>Show / Hide</summary>
<p>
    
```python
##### to torch ######
x_in = torch.from_numpy(train_x_balanced).type(torch.float)
y_true = torch.from_numpy(train_y_balanced).type(torch.float).unsqueeze(1)

### split into train/val ######

train_lim = int(x_in.shape[0] * 0.8)
x_in_train = x_in[:train_lim]
y_true_train = y_true[:train_lim]

x_in_val = x_in[train_lim:]
y_true_val = y_true[train_lim:]

train_data = [(x_in_train[i], y_true_train[i]) for i in range(x_in_train.shape[0])]
val_data = [(x_in_val[i], y_true_val[i]) for i in range(x_in_val.shape[0])]

BATCH_SIZE = 512

train_iterator = data.DataLoader(train_data,
                                 shuffle=True,
                                 batch_size=BATCH_SIZE)

valid_iterator = data.DataLoader(val_data,
                                 batch_size=BATCH_SIZE)

# start from a freshly initialised model so nothing learnt on the imbalanced data carries over
model = MLP(input_size=x_in.shape[1], hidden_size=64).to(device)
optimizer = optim.Adam(model.parameters())

EPOCHS = 100

best_valid_loss = float('inf')
history = []

for epoch in trange(EPOCHS):

    start_time = time.monotonic()

    train_loss, train_acc = train(model, train_iterator, optimizer, criterion, device)
    valid_loss, valid_acc = evaluate(model, valid_iterator, criterion, device)

    # checkpoint: keep the weights with the lowest validation loss so far
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'mlp-model.pt')

    end_time = time.monotonic()

    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    history.append({'epoch': epoch, 'epoch_time': end_time - start_time,
                    'train_loss': train_loss, 'valid_loss': valid_loss,
                    'train_acc': train_acc, 'valid_acc': valid_acc})

    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc*100:.2f}%')
```
    
</p>
</details>
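
The loops above checkpoint the best model but still run for every epoch. Early stopping adds a *patience* counter: stop once the validation loss has not improved for a set number of epochs. Below is a framework-free sketch of that logic; the `early_stopping_epochs` helper, the `patience` value and the loss sequence are hypothetical. In the notebook, the training loop would call `train`/`evaluate` and `torch.save` at the points marked in the comments.

```python
def early_stopping_epochs(valid_losses, patience=5, min_delta=0.0):
    """Return (best_epoch, stop_epoch) for a sequence of validation losses.

    Training stops once the loss has failed to improve by more than
    `min_delta` for `patience` consecutive epochs.
    """
    best_loss = float('inf')
    best_epoch = -1
    epochs_without_improvement = 0
    for epoch, loss in enumerate(valid_losses):
        if loss < best_loss - min_delta:
            best_loss = loss
            best_epoch = epoch            # torch.save(model.state_dict(), ...) here
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return best_epoch, epoch  # break out of the training loop here
    return best_epoch, len(valid_losses) - 1

losses = [0.70, 0.52, 0.41, 0.39, 0.40, 0.42, 0.41, 0.43]
print(early_stopping_epochs(losses, patience=3))  # (3, 6)
```

The checkpoint from epoch 3 is the one to reload afterwards; the epochs after it only confirmed that the model had stopped improving.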

---

## Exercises

Build a DNN to detect the presence of all the elements. To do this, you may go through the following steps:

1. Find all the elements appearing in the dataset; the answer will be `['Zn', 'Sb', 'Si', 'Fe', 'Ag', 'Cu', 'Bi']`.
2. Balance the dataset: if one of the elements appears much less often than the others, it is better to ignore it. If you do everything correctly, you will find the numbers of samples containing each element shown in the following table. Since Ag is much rarer than the rest, we may ignore it in this network.


|  Element | # Samples |
|---|---|
|Zn| 51174|  
|Sb| 51132|  
|Si| 50909|
|Fe| 50764|
|Ag| 10000|
|Cu| 50945|
|Bi| 50784|
    
3. Encode the labels as multi-hot vectors over the element list `['Zn', 'Sb', 'Si', 'Fe', 'Cu', 'Bi']`; for example, a sample containing Fe and Sb gets the vector `[0, 1, 0, 1, 0, 0]`.
4. Build and train a DNN (with an output size of 6) to detect the presence of the six elements.

If you do everything correctly, you will find that the overall accuracy is around 60%. However, the model is not garbage. If we evaluate the accuracy for each element separately, we find that it is nearly 0 for some elements and nearly 100% for others. In other words, the model fails on a few elements, which drags down the overall accuracy, but it can still predict the remaining elements with high accuracy.

**Suggested Answer**

<details> <summary>Show / Hide</summary>
<p>
    
n/a TODO
    
</p>
</details>
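
As a starting point for the exercise, here is a sketch of steps 1 to 3 plus the per-element accuracy mentioned above. It runs on a tiny invented stand-in for `df['Elements']` rather than the real dataframe; the cut-off `min_fraction` used to drop rare elements is a hypothetical choice, and the columns follow order of first appearance rather than the order listed in step 3.

```python
import numpy as np
import pandas as pd

# invented stand-in for df['Elements']: Ag appears once, every other element twice
elements_col = pd.Series([['Fe', 'Sb'], ['Zn', 'Si'], ['Cu', 'Bi'],
                          ['Fe', 'Zn'], ['Sb', 'Si', 'Ag'], ['Cu', 'Bi']])

# step 1: all distinct elements, in order of first appearance
all_elements = list(dict.fromkeys(el for els in elements_col for el in els))

# step 2: count the samples containing each element, then drop rare ones
counts = {el: sum(el in els for els in elements_col) for el in all_elements}
min_fraction = 0.75  # hypothetical cut-off relative to the most common element
keep = [el for el in all_elements if counts[el] >= min_fraction * max(counts.values())]

# step 3: multi-hot labels, one column per kept element
y_multi = np.array([[int(el in els) for el in keep] for els in elements_col])

def per_element_accuracy(y_prob, y_true, threshold=0.5):
    """Accuracy of each output column (element) separately."""
    return ((y_prob > threshold) == y_true.astype(bool)).mean(axis=0)

# pretend predictions: correct everywhere except 'Fe', which is always predicted absent
y_prob = np.where(y_multi == 1, 0.9, 0.1)
y_prob[:, keep.index('Fe')] = 0.2
acc = per_element_accuracy(y_prob, y_multi)
print(dict(zip(keep, acc.round(2))))
```

For the real model, `y_prob` would be the 6-column sigmoid output on the validation set; a 6-output network trained with `nn.BCELoss` treats each column as an independent yes/no decision.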