# CUR matrix decomposition for constructing PaiNN model
This notebook shows how to use CUR matrix decomposition for active learning. We use <a href=https://aip.scitation.org/doi/10.1063/1.3553717>symmetry functions</a> as the fingerprints of atoms and select the most representative atomic environments with CUR. The most important images (those containing important atoms) are then used for labelling.

Note that the procedure can also be used for any other model system and for different fingerprints.

## Generate symmetry functions (fingerprints of atomic environments)
First, we need several helper functions to calculate the values of the symmetry functions. Several symmetry functions are implemented, namely the G2, G3, G4, and G5 functions (see the definitions in <a href=https://compphysvienna.github.io/n2p2/>n2p2</a>).

In [1]:
import torch
from torch import nn
from collections import defaultdict
from torch_scatter import scatter_min

def pair_cutoff(pair_dist, R_c, cutoff_type=2):
    """Evaluate the radial cutoff function for pair distances.

    Args:
        pair_dist: tensor of pair distances; it is combined with ``R_c`` by
            ordinary broadcasting.
        R_c: tensor of cutoff radii.
        cutoff_type: ``1`` selects the cosine form
            ``0.5 * (cos(pi * r / R_c) + 1)``; ``2`` (default) selects
            ``tanh(1 - r / R_c) ** 3``.

    Returns:
        Tensor with the broadcast shape of ``pair_dist`` and ``R_c``.

    Raises:
        ValueError: if ``cutoff_type`` is neither 1 nor 2.

    Note:
        The value is NOT forced to zero for ``r > R_c``; callers are expected
        to pass only distances inside the cutoff (or to mask afterwards).
    """
    if cutoff_type == 1:
        return 0.5 * (torch.cos(pair_dist * torch.pi / R_c) + 1)
    if cutoff_type == 2:
        return torch.tanh(1 - pair_dist / R_c) ** 3

    # Previously an unknown type fell through to `return fc` and failed with
    # an opaque UnboundLocalError; fail with an explicit message instead.
    raise ValueError(
        f"Unsupported cutoff_type {cutoff_type!r}; expected 1 (cosine) or 2 (tanh)."
    )

def G2_SF(tensors, j_elem, eta, R_s, R_c):
    """Radial (G2) symmetry functions of the center atoms.

    For every center atom i the value is the sum over neighbors j of element
    ``j_elem`` of ``exp(-eta * (r_ij - R_s) ** 2) * fc(r_ij)``, where ``fc`` is
    ``pair_cutoff`` with ``cutoff_type=2``.

    Args:
        tensors: dict with keys ``'n_diff'``, ``'n_dist'``, ``'atom_i_idx'``,
            ``'j_elems'`` and ``'image_idx'`` (as built by
            ``BPSymmFunc.get_i_masked``).
        j_elem: element identifier compared against ``tensors['j_elems']``.
        eta, R_s, R_c: 1-D parameter tensors, one entry per symmetry function
            (assumed to have equal length -- the code relies on broadcasting).

    Returns:
        Tensor of shape ``(len(tensors['image_idx']), len(R_c))``.
    """
    device = tensors['n_diff'].device
    eta, R_s, R_c = (param.to(device=device) for param in (eta, R_s, R_c))

    # keep only the pairs whose neighbor has the requested element
    neighbor_mask = tensors['j_elems'] == j_elem
    centers = tensors['atom_i_idx'][neighbor_mask]
    radii = tensors['n_dist'][neighbor_mask].unsqueeze(dim=-1)

    # per-pair contribution, one column per parameter set
    gaussian = torch.exp(-eta * (radii - R_s) ** 2)
    contributions = gaussian * pair_cutoff(radii, R_c, cutoff_type=2)

    # accumulate the pair contributions onto their center atoms
    n_centers = tensors['image_idx'].shape[0]
    out = torch.zeros((n_centers, R_c.shape[0]), device=device)
    out.index_add_(0, centers, contributions)
    return out

def G3_SF(tensors, j_elem, kappa, R_c):
    """Radial (G3) symmetry functions of the center atoms.

    For every center atom i the value is the sum over neighbors j of element
    ``j_elem`` of ``cos(kappa * r_ij) * fc(r_ij)``, where ``fc`` is
    ``pair_cutoff`` with ``cutoff_type=2``.

    Args:
        tensors: dict with keys ``'n_diff'``, ``'n_dist'``, ``'atom_i_idx'``,
            ``'j_elems'`` and ``'image_idx'`` (as built by
            ``BPSymmFunc.get_i_masked``).
        j_elem: element identifier compared against ``tensors['j_elems']``.
        kappa, R_c: 1-D parameter tensors, one entry per symmetry function
            (assumed to have equal length -- the code relies on broadcasting).

    Returns:
        Tensor of shape ``(len(tensors['image_idx']), len(R_c))``.
    """
    device = tensors['n_diff'].device
    kappa, R_c = kappa.to(device=device), R_c.to(device=device)

    # keep only the pairs whose neighbor has the requested element
    neighbor_mask = tensors['j_elems'] == j_elem
    centers = tensors['atom_i_idx'][neighbor_mask]
    radii = tensors['n_dist'][neighbor_mask].unsqueeze(dim=-1)

    # per-pair contribution, one column per parameter set
    contributions = torch.cos(kappa * radii) * pair_cutoff(radii, R_c, cutoff_type=2)

    # accumulate the pair contributions onto their center atoms
    n_centers = tensors['image_idx'].shape[0]
    out = torch.zeros((n_centers, R_c.shape[0]), device=device)
    out.index_add_(0, centers, contributions)
    return out

def G5_SF(tensors, j_elem, k_elem, zeta, Lambda, eta, R_c):
    """Angular symmetry functions without a j-k cutoff term ('G5' here).

    For every center atom i, sums over neighbor pairs (j, k) with elements
    ``j_elem`` / ``k_elem``::

        2 ** (1 - zeta) * (1 + Lambda * cos(theta_ijk)) ** zeta
            * exp(-eta * (r_ij ** 2 + r_ik ** 2)) * fc(r_ij) * fc(r_ik)

    where ``cos(theta_ijk) = (d_ij . d_ik) / (r_ij * r_ik)`` and ``fc`` is
    ``pair_cutoff`` with ``cutoff_type=2``. When ``j_elem == k_elem`` only
    pairs with j-slot < k-slot are counted (strict upper triangle).

    Args:
        tensors: dict with keys ``'n_diff'``, ``'n_dist'``, ``'atom_i_idx'``,
            ``'j_elems'`` and ``'image_idx'`` (as built by
            ``BPSymmFunc.get_i_masked``). ``'atom_i_idx'`` must be grouped by
            center atom, because ``torch.unique_consecutive`` is used.
        j_elem, k_elem: element identifiers compared with ``tensors['j_elems']``.
        zeta, Lambda, eta, R_c: 1-D parameter tensors, one entry per symmetry
            function (assumed equal length -- the code relies on broadcasting).

    Returns:
        Tensor of shape ``(len(tensors['image_idx']), len(R_c))``.
    """
    diff = tensors['n_diff']
    dist = tensors['n_dist']
    pair_i_idx = tensors['atom_i_idx']
    num_atoms = tensors['image_idx'].shape[0]

    zeta = zeta.to(device=diff.device)
    Lambda = Lambda.to(device=diff.device)
    eta = eta.to(device=diff.device)
    R_c = R_c.to(device=diff.device)

    j_mask = (tensors['j_elems'] == j_elem)
    k_mask = (tensors['j_elems'] == k_elem)

    # Without any j- or k-type neighbor there is no triplet to sum over, and
    # `counts.max()` below would fail on an empty tensor: return zeros early.
    if not (bool(j_mask.any()) and bool(k_mask.any())):
        return torch.zeros((num_atoms, R_c.shape[0]), device=diff.device)

    # get relative index of neighbors (slot of each pair within its center atom)
    _, j_inv_idx, j_counts = torch.unique_consecutive(
        pair_i_idx[j_mask], return_inverse=True, return_counts=True,
    )
    _, k_inv_idx, k_counts = torch.unique_consecutive(
        pair_i_idx[k_mask], return_inverse=True, return_counts=True,
    )

    g_idx = torch.arange(j_inv_idx.shape[0], device=diff.device)
    idx_min, _ = scatter_min(g_idx, j_inv_idx)
    j_ridx = g_idx - idx_min[j_inv_idx]

    g_idx = torch.arange(k_inv_idx.shape[0], device=diff.device)
    idx_min, _ = scatter_min(g_idx, k_inv_idx)
    k_ridx = g_idx - idx_min[k_inv_idx]

    # get the matrix of M * N * ...,
    # where M is the number of center atoms,
    # N is the maximum number of their j or k neighbors
    # (unused slots stay zero and are filtered out through `dist_prod` below)
    diff_ij = torch.zeros((num_atoms, j_counts.max(), 3), device=diff.device)
    diff_ik = torch.zeros((num_atoms, k_counts.max(), 3), device=diff.device)
    dist_ij = torch.zeros((num_atoms, j_counts.max()), device=diff.device)
    dist_ik = torch.zeros((num_atoms, k_counts.max()), device=diff.device)

    diff_ij[pair_i_idx[j_mask], j_ridx] = diff[j_mask]
    diff_ik[pair_i_idx[k_mask], k_ridx] = diff[k_mask]
    dist_ij[pair_i_idx[j_mask], j_ridx] = dist[j_mask]
    dist_ik[pair_i_idx[k_mask], k_ridx] = dist[k_mask]

    # calculate the values of different parts in angular symmetry functions
    # TODO: revise part 1 -- the original author notes that this
    # implementation results in a NaN-gradient problem.
    diff_ijk = torch.einsum("ijk, ilk -> ijl", diff_ij, diff_ik)
    dist_prod = (dist_ij.unsqueeze(dim=-1) * dist_ik.unsqueeze(dim=-2))

    # handling situation that j = k: keep each unordered pair once and drop
    # the j == k diagonal
    if j_elem == k_elem:
        dist_prod = torch.triu(dist_prod, diagonal=1)

    # non-zero products mark the real (non-padded) triplets
    idx_i, idx_j, idx_k = torch.where(dist_prod)

    part_1 = diff_ijk[idx_i, idx_j, idx_k] / dist_prod[idx_i, idx_j, idx_k]
    part_1 = (part_1.unsqueeze(dim=1) * Lambda + 1) ** zeta

    part_2 = torch.exp(-eta * (dist_ij[idx_i, idx_j] ** 2 + dist_ik[idx_i, idx_k] ** 2).unsqueeze(dim=-1))

    pair_fc_ij = pair_cutoff(dist_ij[idx_i, idx_j].unsqueeze(dim=-1), R_c, cutoff_type=2)
    pair_fc_ik = pair_cutoff(dist_ik[idx_i, idx_k].unsqueeze(dim=-1), R_c, cutoff_type=2)
    part_3 = pair_fc_ij * pair_fc_ik

    sfs = torch.zeros((num_atoms, R_c.shape[0]), device=diff.device)
    sfs = sfs.index_add(0, idx_i, part_1 * part_2 * part_3 * 2 ** (1-zeta))

    return sfs

def G4_SF(tensors, j_elem, k_elem, zeta, Lambda, eta, R_c):
    """Angular symmetry functions including a j-k cutoff term ('G4' here).

    For every center atom i, sums over neighbor pairs (j, k) with elements
    ``j_elem`` / ``k_elem`` and ``r_jk < R_c``::

        2 ** (1 - zeta) * (1 + Lambda * cos(theta_ijk)) ** zeta
            * exp(-eta * (r_ij ** 2 + r_ik ** 2))
            * fc(r_ij) * fc(r_ik) * fc(r_jk)

    where ``cos(theta_ijk) = (d_ij . d_ik) / (r_ij * r_ik)``,
    ``r_jk = |d_ik - d_ij|`` and ``fc`` is ``pair_cutoff`` with
    ``cutoff_type=2``. When ``j_elem == k_elem`` only pairs with
    j-slot < k-slot are counted (strict upper triangle).

    Args:
        tensors: dict with keys ``'n_diff'``, ``'n_dist'``, ``'atom_i_idx'``,
            ``'j_elems'`` and ``'image_idx'`` (as built by
            ``BPSymmFunc.get_i_masked``). ``'atom_i_idx'`` must be grouped by
            center atom, because ``torch.unique_consecutive`` is used.
        j_elem, k_elem: element identifiers compared with ``tensors['j_elems']``.
        zeta, Lambda, eta, R_c: 1-D parameter tensors, one entry per symmetry
            function (assumed equal length -- the code relies on broadcasting).

    Returns:
        Tensor of shape ``(len(tensors['image_idx']), len(R_c))``.
    """
    diff = tensors['n_diff']
    dist = tensors['n_dist']
    pair_i_idx = tensors['atom_i_idx']
    num_atoms = tensors['image_idx'].shape[0]

    zeta = zeta.to(device=diff.device)
    Lambda = Lambda.to(device=diff.device)
    eta = eta.to(device=diff.device)
    R_c = R_c.to(device=diff.device)

    j_mask = (tensors['j_elems'] == j_elem)
    k_mask = (tensors['j_elems'] == k_elem)

    # Without any j- or k-type neighbor there is no triplet to sum over, and
    # `counts.max()` below would fail on an empty tensor: return zeros early.
    if not (bool(j_mask.any()) and bool(k_mask.any())):
        return torch.zeros((num_atoms, R_c.shape[0]), device=diff.device)

    # get relative index of neighbors (slot of each pair within its center atom)
    _, j_inv_idx, j_counts = torch.unique_consecutive(
        pair_i_idx[j_mask], return_inverse=True, return_counts=True,
    )
    _, k_inv_idx, k_counts = torch.unique_consecutive(
        pair_i_idx[k_mask], return_inverse=True, return_counts=True,
    )

    g_idx = torch.arange(j_inv_idx.shape[0], device=diff.device)
    idx_min, _ = scatter_min(g_idx, j_inv_idx)
    j_ridx = g_idx - idx_min[j_inv_idx]

    g_idx = torch.arange(k_inv_idx.shape[0], device=diff.device)
    idx_min, _ = scatter_min(g_idx, k_inv_idx)
    k_ridx = g_idx - idx_min[k_inv_idx]

    # get the matrix of M * N * ...,
    # where M is the number of center atoms,
    # N is the maximum number of their j or k neighbors
    # (unused slots stay zero and are filtered out through `dist_prod` below)
    diff_ij = torch.zeros((num_atoms, j_counts.max(), 3), device=diff.device)
    diff_ik = torch.zeros((num_atoms, k_counts.max(), 3), device=diff.device)
    dist_ij = torch.zeros((num_atoms, j_counts.max()), device=diff.device)
    dist_ik = torch.zeros((num_atoms, k_counts.max()), device=diff.device)

    diff_ij[pair_i_idx[j_mask], j_ridx] = diff[j_mask]
    diff_ik[pair_i_idx[k_mask], k_ridx] = diff[k_mask]
    dist_ij[pair_i_idx[j_mask], j_ridx] = dist[j_mask]
    dist_ik[pair_i_idx[k_mask], k_ridx] = dist[k_mask]

    # calculate the values of different parts in angular symmetry functions
    # TODO: revise part 1 -- the original author notes that this
    # implementation results in a NaN-gradient problem.
    diff_ijk = torch.einsum("ijk, ilk -> ijl", diff_ij, diff_ik)
    dist_prod = (dist_ij.unsqueeze(dim=-1) * dist_ik.unsqueeze(dim=-2))

    # handling situation that j = k: keep each unordered pair once and drop
    # the j == k diagonal
    if j_elem == k_elem:
        dist_prod = torch.triu(dist_prod, diagonal=1)

    # non-zero products mark the real (non-padded) triplets
    idx_i, idx_j, idx_k = torch.where(dist_prod)
    pair_dist_jk = torch.norm(diff_ik[idx_i, idx_k] - diff_ij[idx_i, idx_j], dim=-1)

    part_1 = diff_ijk[idx_i, idx_j, idx_k] / dist_prod[idx_i, idx_j, idx_k]
    part_1 = (part_1.unsqueeze(dim=1) * Lambda + 1) ** zeta

    part_2 = torch.exp(-eta * (dist_ij[idx_i, idx_j] ** 2 + dist_ik[idx_i, idx_k] ** 2).unsqueeze(dim=-1))

    pair_fc_ij = pair_cutoff(dist_ij[idx_i, idx_j].unsqueeze(dim=-1), R_c, cutoff_type=2)
    pair_fc_ik = pair_cutoff(dist_ik[idx_i, idx_k].unsqueeze(dim=-1), R_c, cutoff_type=2)
    pair_fc_jk = pair_cutoff(pair_dist_jk.unsqueeze(dim=-1), R_c, cutoff_type=2)
    part_3 = pair_fc_ij * pair_fc_ik * pair_fc_jk

    contributions = part_1 * part_2 * part_3 * 2 ** (1-zeta)

    # Drop triplets whose j-k distance is outside the cutoff. The previous
    # code indexed the 1-D `idx_i` and the (T, K) contributions with a
    # (T, 1) boolean mask, which is not a valid index for either tensor.
    # Masking element-wise keeps the shapes intact and applies each column's
    # own R_c (identical to using R_c[0] when all cutoffs are equal).
    jk_within_cutoff = pair_dist_jk.unsqueeze(dim=-1) < R_c
    contributions = torch.where(
        jk_within_cutoff, contributions, torch.zeros_like(contributions)
    )

    sfs = torch.zeros((num_atoms, R_c.shape[0]), device=diff.device)
    sfs = sfs.index_add(0, idx_i, contributions)

    return sfs

# Dispatch table: maps the symmetry-function type name found under the 'type'
# key of a spec entry to the function that evaluates it.
bp_sf_fns = {'G2': G2_SF, 'G3': G3_SF, 'G4': G4_SF, 'G5': G5_SF}
class BPSymmFunc:
    """
    Evaluate Behler-Parrinello style symmetry functions for a batch of images.

    ``sf_spec`` maps an element symbol to a list of spec dicts. Each spec dict
    has a ``'type'`` key (looked up in ``bp_sf_fns``); all remaining items are
    forwarded as keyword arguments to the symmetry function, with list values
    converted to ``torch.FloatTensor``.
    """
    def __init__(self, sf_spec, compute_forces=True):
        self.sf_spec = defaultdict(list)
        # stored only; not read anywhere inside this class
        self.compute_forces = compute_forces
        for element, entries in sf_spec.items():
            for entry in entries:
                sf_fn = bp_sf_fns[entry['type']]
                kwargs = {}
                for name, value in entry.items():
                    if name == 'type':
                        continue
                    # parameter lists become float tensors, scalars pass through
                    kwargs[name] = torch.FloatTensor(value) if isinstance(value, list) else value
                self.sf_spec[element].append((sf_fn, kwargs))

    def __call__(self, tensors):
        """Return ``{element: {'sfs': ndarray, 'image_idx': ndarray}}`` for a batch.

        ``tensors`` must provide ``'num_atoms'``, ``'num_pairs'``, ``'pairs'``,
        ``'n_diff'`` and ``'elems'``.
        """
        atoms_per_image = tensors['num_atoms']
        pairs_per_image = tensors['num_pairs']
        raw_pairs = tensors['pairs']

        # index of the first atom of every image inside the concatenated batch
        leading_zero = torch.tensor(
            [0], device=atoms_per_image.device, dtype=atoms_per_image.dtype,
        )
        first_atom = torch.cumsum(
            torch.cat((leading_zero, atoms_per_image[:-1])), dim=0,
        )
        # shift the per-image atom indices of each pair to batch-wide indices
        shift = torch.repeat_interleave(first_atom, pairs_per_image)
        pairs = raw_pairs + shift.unsqueeze(-1)

        # image number of every atom in the batch
        image_idx = torch.repeat_interleave(
            torch.arange(tensors['num_atoms'].shape[0], device=pairs.device),
            atoms_per_image,
        )

        # length of every pair vector
        n_dist = torch.linalg.norm(tensors['n_diff'], dim=1)

        fps = {}
        for elem, elem_spec in self.sf_spec.items():
            i_masked = self.get_i_masked(
                tensors, pairs, image_idx, n_dist, i_elem=atomic_numbers[elem],
            )
            # one column block per spec entry, stacked side by side
            sfs = torch.hstack([fn(i_masked, **options) for fn, options in elem_spec])
            fps[elem] = {
                'sfs': sfs.detach().cpu().numpy(),
                'image_idx': i_masked['image_idx'].detach().cpu().numpy(),
            }

        return fps

    def get_i_masked(self, tensors, pairs, image_idx, n_dist, i_elem):
        """Select the pairs whose center atom (column 0 of ``pairs``) is ``i_elem``.

        Returns a dict with the image index of every atom of that element and,
        for the selected pairs, a consecutive center-atom index, the pair
        distance, the pair vector and the neighbor element.

        NOTE(review): ``'atom_i_idx'`` numbers only the centers that occur in
        ``pairs`` (via ``torch.unique_consecutive``), whereas ``'image_idx'``
        lists every atom of the element; the two line up only when each such
        atom has at least one pair -- confirm against the data loader.
        """
        elems = tensors['elems']
        center_idx = pairs[:, 0]
        i_mask = elems[center_idx] == i_elem

        # renumber the selected centers 0..M-1 in order of appearance
        _, inverse_indices = torch.unique_consecutive(
            center_idx[i_mask], return_inverse=True,
        )

        return {
            'image_idx': image_idx[elems == i_elem],
            'atom_i_idx': inverse_indices,            # center index of each selected pair
            'n_dist': n_dist[i_mask],                 # distances of the selected pairs
            'n_diff': tensors['n_diff'][i_mask],      # distance vectors of the selected pairs
            'j_elems': elems[pairs[:, 1]][i_mask],    # neighbor element of each selected pair
        }

We write a function to read the symmetry function parameters. The format of <font color = blue>input.nn</font> is the same as that used by <a href=https://www.uni-goettingen.de/de/560580.html>RuNNer</a> and <a href=https://compphysvienna.github.io/n2p2/>n2p2</a>.

In [2]:
import re
from ase.data import atomic_numbers

# n2p2 symmetry-function type number -> name of the implementation used here
n2p2_sf_type = {'2': 'G2', '3': 'G4', '9': 'G5'}
def get_sf_dict(input_file = 'input.nn'):
    """Read symmetry function parameters from an ``input.nn`` style file.

    Only lines starting with ``'symfunc'`` are parsed. After whitespace
    splitting, a line is treated as

    * angular, when the fifth token starts with a letter:
      ``<keyword> <elem> <type> <j_elem> <k_elem> <eta> <Lambda> <zeta> <R_c>``
    * radial otherwise:
      ``<keyword> <elem> <type> <j_elem> <eta> <R_s> <R_c>``

    Lines that share the same element / type / neighbor element(s) are merged
    into a single entry whose parameters are lists.

    Args:
        input_file: path of the file to read.

    Returns:
        dict mapping the center element symbol to a list of spec dicts with
        keys ``'type'``, ``'j_elem'`` (atomic number), optionally ``'k_elem'``
        (atomic number), and the parameter lists.

    Raises:
        KeyError: for a type number missing from ``n2p2_sf_type`` or an
            unknown element symbol.
    """
    # the context manager closes the file (the handle used to be left open)
    with open(input_file, 'r') as handle:
        lines = [line.strip() for line in handle if line.startswith('symfunc')]

    sf_specs = {}
    for line in lines:
        fields = line.split()
        if re.match('[a-zA-Z]', fields[4]):
            # angular function: two neighbor elements
            params = sf_specs.setdefault(
                " ".join(fields[1:5]),
                {'eta': [], 'Lambda': [], 'zeta': [], 'R_c': []},
            )
            params['eta'].append(float(fields[5]))
            params['Lambda'].append(float(fields[6]))
            params['zeta'].append(float(fields[7]))
            params['R_c'].append(float(fields[8]))
        else:
            # radial function: one neighbor element
            params = sf_specs.setdefault(
                " ".join(fields[1:4]),
                {'eta': [], 'R_s': [], 'R_c': []},
            )
            params['eta'].append(float(fields[4]))
            params['R_s'].append(float(fields[5]))
            params['R_c'].append(float(fields[6]))

    new_sf_specs = {}
    for key, params in sf_specs.items():
        sf_type = key.split()
        sf_dict = {'type': n2p2_sf_type[sf_type[1]], 'j_elem': atomic_numbers[sf_type[2]]}
        sf_dict.update(params)
        if len(sf_type) == 4:
            sf_dict['k_elem'] = atomic_numbers[sf_type[3]]
        new_sf_specs.setdefault(sf_type[0], []).append(sf_dict)

    return new_sf_specs

In [3]:
# Parse the symmetry-function definitions from the default file ('input.nn').
sf_specs = get_sf_dict()

## Load dataset and calculate atomic importance

We use a test trajectory to show how CUR works.

In [4]:
from PaiNN.data import AseDataset, collate_atomsdata
# Load the demo trajectory; `cutoff` is forwarded to AseDataset
# (presumably the neighbor-list radius -- confirm in PaiNN.data).
dataset = AseDataset('../demo.traj', cutoff=6.0)
# Iterate over the dataset 16 images at a time, batching with PaiNN's collate function.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=16,
    collate_fn=collate_atomsdata,
)

In [5]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Fingerprint calculator built from the parsed symmetry-function specs.
symfunc = BPSymmFunc(sf_specs)

In [7]:
# collect fingerprints of every atom, grouped by element
import numpy as np
fps = {}
batch_num = 0    # number of images processed so far
for device_batch in dataloader:
    batch = {k: v.to(device=device, non_blocking=True) for (k, v) in device_batch.items()}
    batch_fps = symfunc(batch)
    for elem, elem_fps in batch_fps.items():
        # turn batch-local image indices into dataset-wide ones
        elem_fps['image_idx'] += batch_num
        collected = fps.setdefault(elem, {})
        for k, v in elem_fps.items():
            collected.setdefault(k, []).append(v)

    batch_num += len(batch['num_atoms'])

for elem in fps:
    # merge the per-batch arrays of this element
    fps[elem] = {k: np.concatenate(v) for k, v in fps[elem].items()}

    # shuffle different atoms, remembering the permutation under 'index'
    p = np.random.permutation(len(fps[elem]['image_idx']))
    fps[elem]['index'] = p
    fps[elem]['sfs'] = fps[elem]['sfs'][p]
    fps[elem]['image_idx'] = fps[elem]['image_idx'][p]

In [None]:
import os
from CUR import div_mat, CUR_decomposition_row
from multiprocessing import Pool
# second argument of div_mat (presumably controls how the fingerprint matrix
# is split -- confirm in the CUR module)
mat_num = 500
# processors = os.cpu_count()
processors = 10    # number of worker processes

for elem in fps:
    split_sfs = div_mat(fps[elem]['sfs'], mat_num)
    # decompose the sub-matrices in parallel; each result is a (W, C) pair
    with Pool(processors) as pool:
        results = pool.map(CUR_decomposition_row, split_sfs)
    W_mat = [W for W, _ in results]
    C_arr = [C for _, C in results]
    W_mat = np.concatenate(W_mat, axis=1)
    # importance score = column norm of the concatenated W matrices
    fps[elem]['atomic_importance'] = np.linalg.norm(W_mat, axis=0)

## Add up atomic importance to get image importance
By now we have obtained an importance score for each atom. We then add up the atomic importance scores within each image to obtain the importance of that image.

In [11]:
num_selected = 100              # number of configurations we need to select from an iteration
image_importance = np.zeros(len(dataset))

# sum the atomic importance of all atoms (of every element) belonging to an image;
# np.add.at accumulates correctly when an image index occurs more than once
for elem in fps.keys():
    np.add.at(image_importance, fps[elem]['image_idx'], fps[elem]['atomic_importance'])

# argsort is ascending, so the tail holds the most important images.
# The slice now honors `num_selected` (it used to be hard-coded to 100).
# num_selected must be positive: a value of 0 would select every image.
selected_indices = np.argsort(image_importance)[-num_selected:]

Then we can label the representative images by DFT after getting their indices.