# ARPESNet compression applied to clustering simulated nanoARPES data

In this notebook we explore an application of the compression provided by ARPESNet. We use a simulated nanoARPES dataset, a spatial map in which each pixel holds a 2D ARPES spectrum, and compress it with ARPESNet. We then cluster the compressed data and compare the result with clustering performed on the original, uncompressed data.

In [None]:
import os
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.cluster import KMeans
from tqdm.auto import tqdm

import arpesnet as an

print("Python", sys.version)
GPU_ENABLED = torch.backends.mps.is_available() or torch.cuda.is_available()
print(f"Pytorch version: {torch.__version__} | GPU enabled = {GPU_ENABLED}")

# load data
### Set the path to the data directory
The `root` variable should be set to the path of the directory containing the data files. It should contain:
- cluster_centers.pt

We start by loading the 5 reference spectra stored in `cluster_centers.pt`, all from a single material.

In [None]:
root = Path(r"path/to/your/data/folder/")
assert root.exists()
cluster_centers = torch.load(root/"cluster_centers.pt")

In [None]:
fig, ax = plt.subplots(2,len(cluster_centers),figsize=(6,3))

ref = cluster_centers[2]
for i,img in enumerate(cluster_centers):
    # img = tr.pipe(cut.values,transform_nonoise).numpy()
    ax[0,i].imshow(img.numpy(), cmap='viridis', interpolation='none',origin='lower',aspect='equal')
    ax[0,i].set_title(f'{i}')
    ax[0,i].set_xlabel('ky')
    ax[0,i].set_ylabel('E')
    # turn off axis
    ax[0,i].axis('off')
    diff = img-ref
    ax[1,i].imshow(diff.numpy(), cmap='bwr', clim=(-100,100), interpolation='none',origin='lower',aspect='equal')
    ax[1,i].axis('off')



# add Poissonian noise
We add Poissonian noise to the data to simulate the noise in real ARPES data. `n_counts` is the average number of counts per spectrum in the noisy data; lowering it simulates acquiring data for shorter times.
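
As a rough illustration of what a Poisson exposure means, the sketch below rescales a noise-free image so that its total intensity equals `n_counts`, then draws every pixel from a Poisson distribution with that expected value. It is a NumPy stand-in written for this explanation only: the function name `poisson_exposure` and the synthetic test image are hypothetical, and it is an assumption that `an.transform.SetRandomPoissonExposure` behaves in a similar way.

```python
import numpy as np


def poisson_exposure(img, n_counts, rng=None):
    """Hypothetical helper: resample `img` as if ~n_counts electrons were detected."""
    rng = np.random.default_rng() if rng is None else rng
    # intensities are treated as expected counts, so they must be non-negative
    img = np.clip(np.asarray(img, dtype=float), 0, None)
    # rescale so that the expected total number of counts is n_counts
    expected = img * (n_counts / img.sum())
    # every pixel is an independent Poisson draw around its expected value
    return rng.poisson(expected)


# synthetic smooth "spectrum", used only for this illustration
clean = np.outer(np.hanning(32), np.hanning(32))
noisy_sketch = poisson_exposure(clean, 10_000, rng=np.random.default_rng(0))
print(noisy_sketch.shape, noisy_sketch.sum())
```

The total number of counts fluctuates around `n_counts` with a standard deviation of about `sqrt(n_counts)`, which is why lower values give visibly noisier spectra.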

In [None]:
n_counts = 10_000
noiser = an.transform.SetRandomPoissonExposure(n_counts)

In [None]:
fig, ax = plt.subplots(2,len(cluster_centers),figsize=(6,3))
noisy = [noiser(img) for img in cluster_centers]
ref = noisy[2]
for i,img in enumerate(noisy):
    ax[0,i].imshow(img.numpy(), cmap='viridis', interpolation='none',origin='lower',aspect='equal')
    ax[0,i].set_title(f'{i}')
    ax[0,i].set_xlabel('ky')
    ax[0,i].set_ylabel('E')
    # turn off axis
    ax[0,i].axis('off')
    diff = img-ref
    vmax = np.max(np.abs(img.numpy()))
    ax[1,i].set_title(f'N: {img.sum()}')
    ax[1,i].imshow(diff.numpy(), cmap='bwr', clim=(-vmax,vmax), interpolation='none',origin='lower',aspect='equal')
    ax[1,i].axis('off')



# create a ground truth map
We create a ground truth map for the data, which will be used to evaluate the clustering results. The ground truth map is a 2D array with the same shape as the spatial map, where each pixel holds an integer cluster label. Pixels with the same value belong to the same cluster. We then assign a noisy spectrum to each pixel, so that the spectra in the same cluster (of the ground truth map) originate from the same noise-free spectrum but have different noise patterns.
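
The striped label map described above can also be built in a single expression. The sketch below is an equivalent construction using `np.repeat` and `np.tile`; the helper name `striped_ground_truth` and its parameters are hypothetical and only serve this illustration.

```python
import numpy as np


def striped_ground_truth(n_clusters=5, stripe_width=10, length=50):
    """Hypothetical helper: horizontal stripes, one integer label per stripe."""
    # label of each row: 0,0,...,1,1,...,(n_clusters-1)
    row_labels = np.repeat(np.arange(n_clusters), stripe_width)
    # copy the row label across all columns
    return np.tile(row_labels[:, None], (1, length)).astype(np.int64)


ground_truth_sketch = striped_ground_truth()
print(ground_truth_sketch.shape, np.unique(ground_truth_sketch))
```

With the default arguments, pixel `(row, col)` carries the label `row // 10`, which matches the transposed array built in the cell that follows.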

In [None]:
arr = np.zeros((50,50))
arr[:,:10] = 0
arr[:,10:20] = 1
arr[:,20:30] = 2
arr[:,30:40] = 3
arr[:,40:] = 4
arr = arr.T
plt.figure(figsize=(2,2))
plt.imshow(arr,cmap="RdBu", interpolation='none')
ground_truth = arr.astype(np.int64)
n_clusters = 5

### Assign spectra to each pixel

In [None]:
flatmap = []
for i in tqdm(
    ground_truth.ravel(),
    total=len(ground_truth.ravel()),
    desc="generate noisy map",
):
    flatmap.append(noiser(cluster_centers[i]))
flatmap = torch.stack(flatmap).to(torch.float32)


In [None]:
plt.figure(figsize=(2,2))
intensity_map = flatmap.sum((1,2)).reshape(50,50)
plt.imshow(intensity_map,cmap='RdBu',interpolation='none')
print(intensity_map.mean())

# ARPESNet

## load trained model
We load a pre-trained ARPESNet model. The model was trained on a dataset of ARPES spectra from various materials. The model is a convolutional autoencoder, which compresses the input spectra into a lower-dimensional representation and then reconstructs the input spectra from the compressed representation.

In [None]:
arpesnet = an.load_trainer("../trained_model/arpesnet_n2n_4k.pth")

# test reconstruction of noisy data
We test the reconstruction of the noisy data by passing it through the ARPESNet model. We compare the original noisy data with the reconstructed data.

In [None]:
fig, ax = plt.subplots(3,len(cluster_centers),figsize=(6,3))
noisy = [noiser(img) for img in cluster_centers]
reconstructed = [arpesnet.eval(img) for img in noisy]
ref = reconstructed[2]
for i, img, rec in zip(range(len(noisy)), noisy, reconstructed):
    ax[0,i].imshow(img.numpy(), cmap='viridis', interpolation='none',origin='lower',aspect='equal')
    ax[0,i].set_title(f'{i}')
    ax[0,i].set_xlabel('ky')
    ax[0,i].set_ylabel('E')
    # turn off axis
    ax[0,i].axis('off')
    ax[1,i].imshow(rec.numpy(), cmap='viridis', interpolation='none',origin='lower',aspect='equal')
    ax[1,i].axis('off')
    diff = rec-ref
    vmax = np.max(np.abs(img.numpy()))
    ax[2,i].imshow(diff.numpy(), cmap='bwr', clim=(-vmax,vmax), interpolation='none',origin='lower',aspect='equal')
    ax[2,i].axis('off')



# encode the noisy map data
To prepare for clustering, we encode the noisy map data using the ARPESNet model. We pass the noisy map data through the encoder part of the ARPESNet model to obtain the compressed representation of the data.

In [None]:
encoded = torch.stack([arpesnet.encode(img).cpu().detach().squeeze().flatten() for img in flatmap])
encoded.shape

# clustering
We perform clustering on the compressed data using k-means clustering. We choose the number of clusters to be the same as the number of clusters in the ground truth map.
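
k-means numbers its clusters arbitrarily, so the predicted labels have to be matched to the ground truth labels before an accuracy can be computed. Trying every permutation works for 5 clusters but grows factorially with the number of clusters. As an alternative sketch, the matching can be posed as an assignment problem on the overlap (confusion) matrix and solved with `scipy.optimize.linear_sum_assignment`. The function name `match_labels` is hypothetical, and this is a swapped-in technique, not the method used in this notebook.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def match_labels(pred_labels, true_labels):
    """Hypothetical helper: rename predicted labels to best overlap the true ones."""
    pred_labels = np.asarray(pred_labels)
    true_labels = np.asarray(true_labels)
    pred_names = np.unique(pred_labels)
    true_names = np.unique(true_labels)
    # overlap[i, j] = number of samples with predicted label i and true label j
    overlap = np.array(
        [[np.sum((pred_labels == p) & (true_labels == t)) for t in true_names]
         for p in pred_names]
    )
    # one-to-one pairing of predicted and true labels that maximizes total overlap
    rows, cols = linear_sum_assignment(overlap, maximize=True)
    # predicted labels left without a partner keep the placeholder value -1
    remapped = np.full(pred_labels.shape, -1, dtype=true_labels.dtype)
    for r, c in zip(rows, cols):
        remapped[pred_labels == pred_names[r]] = true_names[c]
    accuracy = float(np.mean(remapped == true_labels))
    return accuracy, remapped


true_demo = np.array([0, 0, 1, 1, 2, 2])
pred_demo = np.array([2, 2, 0, 0, 1, 0])
acc_demo, remapped_demo = match_labels(pred_demo, true_demo)
print(acc_demo, remapped_demo)
```

In the toy example one of the six samples is assigned to the wrong cluster, so the best possible relabeling reaches an accuracy of 5/6.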

In [None]:
from itertools import permutations

def remap_labels(pred_labels, true_labels) -> tuple:
    """Rename prediction labels (clustered output) to best match true labels."""
    pred_labels, true_labels = np.array(pred_labels), np.array(true_labels)
    assert pred_labels.ndim == 1 == true_labels.ndim
    assert len(pred_labels) == len(true_labels)
    cluster_names = np.unique(pred_labels)
    accuracy = 0

    perms = np.array(list(permutations(np.unique(true_labels))))

    remapped_labels = true_labels
    for perm in perms:
        flipped_labels = np.zeros(len(true_labels))
        for label_index, label in enumerate(cluster_names):
            flipped_labels[pred_labels == label] = perm[label_index]

        testAcc = np.sum(flipped_labels == true_labels) / len(true_labels)
        if testAcc > accuracy:
            accuracy = testAcc
            remapped_labels = flipped_labels

    return accuracy, remapped_labels

## compute kmeans
We compute k-means clustering on the compressed data.
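
For readers unfamiliar with k-means, the sketch below shows the core iteration (Lloyd's algorithm) in plain NumPy: assign every sample to its nearest center, then move each center to the mean of its members. It is a simplified illustration on synthetic 2D points; the function name `lloyd_kmeans` and the toy data are hypothetical. scikit-learn's `KMeans` additionally uses k-means++ initialization and repeats the fit `n_init` times, keeping the best run.

```python
import numpy as np


def lloyd_kmeans(x, init_centers, n_iter=20):
    """Hypothetical helper: plain Lloyd iterations from given starting centers."""
    centers = np.array(init_centers, dtype=float)
    for _ in range(n_iter):
        # squared Euclidean distance of every sample to every center
        d2 = ((x[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
        # assignment step: nearest center wins
        labels = d2.argmin(axis=1)
        # update step: move each center to the mean of its members
        for k in range(len(centers)):
            members = x[labels == k]
            if len(members):
                centers[k] = members.mean(axis=0)
    return labels, centers


rng = np.random.default_rng(0)
true_centers = np.array([[0.0, 0.0], [10.0, 0.0], [0.0, 10.0]])
# three well-separated blobs of 50 points each
x_demo = np.concatenate(
    [c + rng.normal(scale=0.5, size=(50, 2)) for c in true_centers]
)
# start from one sample of each blob to keep the example deterministic
labels_demo, centers_demo = lloyd_kmeans(x_demo, init_centers=x_demo[[0, 50, 100]])
print(centers_demo.round(2))
```

In the notebook each sample is an encoded spectrum instead of a 2D point, but the procedure is the same.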

In [None]:
km = KMeans(n_clusters=n_clusters, n_init=100)
result = km.fit(encoded)
sorted_labels = remap_labels(result.labels_, ground_truth.ravel())[1].reshape(50,50)


In [None]:
fig,ax = plt.subplots(1,2,figsize=(5,3))
for a in ax:
    a.axis('off')
ax[0].imshow(ground_truth, cmap='RdBu', interpolation='none')
ax[0].set_title('ground truth')
ax[1].imshow(sorted_labels, cmap='RdBu', interpolation='none')
ax[1].set_title('clustering result')
accuracy = np.sum(ground_truth == sorted_labels) / (50*50)
plt.suptitle(f'clustering with {n_counts:,.0f} counts | Accuracy={accuracy:.2%}');

# clustering uncompressed data
Let's now compare with clustering of the uncompressed data. This takes longer to run, typically a few minutes depending on your hardware.

In [None]:
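# note: km.fit returns the estimator itself, so refitting also overwrites `result.labels_`;
# `sorted_labels` from the compressed data was already extracted above and is unaffected
result_raw = km.fit(flatmap.view(50*50,-1).cpu().numpy())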
sorted_labels_raw = remap_labels(result_raw.labels_, ground_truth.ravel())[1].reshape(50,50)

In [None]:
fig,ax = plt.subplots(1,3,figsize=(10,4))
for a in ax:
    a.axis('off')
ax[0].imshow(ground_truth, cmap='RdBu', interpolation='none')
ax[0].set_title('ground truth')
ax[1].imshow(sorted_labels_raw, cmap='RdBu', interpolation='none')
ax[1].set_title(f'raw: {np.sum(ground_truth == sorted_labels_raw) / (50*50):.2%}')
ax[2].imshow(sorted_labels, cmap='RdBu', interpolation='none')
ax[2].set_title(f'ARPESNet: {np.sum(ground_truth == sorted_labels) / (50*50):.2%}')

accuracy = np.sum(ground_truth == sorted_labels) / (50*50)
plt.suptitle(f'clustering with {n_counts:,.0f} counts');