# Aims
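* demonstrate a simple implementation of the \\( W^T P_l \\) compression of the SOAP power spectrum
* show that this compression can be "undone", i.e. that \\( P_l \\) can be reconstructed from \\( W^T P_l \\) and the random key \\( W \\)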

In [43]:
import quippy
from quippy import descriptors


import ase
from ase.io import read, write
import scipy
import periodictable as PT
import random

import os
import glob
import zipfile
import urllib.request
import numpy as np
import pprint

# Notebook-wide display settings: numpy arrays print with 3 decimals,
# and `pp` is an indented pretty-printer for nested dicts.
np.set_printoptions(precision=3)
pp = pprint.PrettyPrinter(indent=2)

from matplotlib import pyplot as plt
%matplotlib notebook

# Helper functions to set up SOAP parameters
Note: these are not necessarily "good" parameters; this is just a toy example.

In [44]:
def dataset_to_params(dataset, n_max=6, l_max=3, cutoff=5, atom_sigma=0.4):
    """Helper function to construct a dictionary of SOAP params from a given dataset.

    Parameters
    ----------
    dataset : iterable of structures
        Each structure is iterated to obtain atoms with a ``.symbol`` attribute
        (e.g. a list of ``ase.Atoms``).
    n_max, l_max, cutoff, atom_sigma : optional
        SOAP hyper-parameters. The defaults are the original toy settings.

    Returns
    -------
    dict
        ``n_Z`` / ``n_species``: number of distinct chemical symbols;
        ``Z`` / ``species_Z``: their atomic numbers, sorted, formatted as
        ``"{z1 z2 ...}"`` (``"{}"`` for an empty dataset); plus ``n_max``,
        ``l_max``, ``soap cutoff`` and ``atom_sigma``. Keys are inserted in a
        fixed order because ``params2quippy_str`` serialises them in dict order.
    """
    species = {atom.symbol for structure in dataset for atom in structure}
    # periodictable exposes each element as a module attribute named by its symbol
    Zs = sorted(vars(PT)[symbol].number for symbol in species)
    # join() handles the empty case correctly (the old "strip last char" trick
    # turned an empty list into "}")
    Zstr = "{" + " ".join(str(Z) for Z in Zs) + "}"

    params = {}
    params["n_Z"] = len(species)
    params["n_species"] = len(species)
    params["Z"] = Zstr
    params["species_Z"] = Zstr
    params["n_max"] = n_max
    params["l_max"] = l_max
    # Serialises to "soap cutoff=5". NOTE(review): the bare "soap" token presumably
    # selects the quippy descriptor type while "cutoff" sets the radius - confirm
    # before renaming this key.
    params["soap cutoff"] = cutoff
    params["atom_sigma"] = atom_sigma
    return params

def params2quippy_str(params):
    """Serialise a parameter dict into a quippy argument string.

    Each entry becomes ``key=value`` followed by one space, in dict order, so a
    non-empty result carries a trailing space and an empty dict gives "".
    """
    return "".join(f"{key!s}={val!s} " for key, val in params.items())

# Compression Functions

In [45]:
def compress_ps(ps, params, random_key, norm=True):
    """Compress the power spectrum using W^T P_l.

    The flat power spectrum is unpacked (via ``ps2slices``) into one symmetric
    ``(N*S, N*S)`` matrix ``P_l`` per angular channel ``l``, and each ``P_l`` is
    projected onto the first ``2*l+1`` columns of ``random_key``.

    Parameters
    ----------
    ps : 1-D array
        Flat power spectrum in the layout ``ps2slices`` expects.
    params : dict
        Must provide ``n_max``, ``l_max`` and ``n_Z``.
    random_key : ndarray with N*S rows
        Projection matrix (e.g. from ``gen_random_key``). Slicing caps the
        columns used for channel ``l`` at ``min(2*l+1, n_columns)``.
    norm : bool
        If true, scale the result to unit Euclidean norm. An all-zero result is
        returned unchanged instead of being divided by zero.

    Returns
    -------
    ndarray
        Concatenation of the flattened ``W^T P_l`` blocks for ``l = 0..l_max``.
    """
    N = round(params["n_max"])
    L = round(params["l_max"])
    S = round(params["n_Z"])

    # 1. split up into one (N*S, N*S) matrix per l
    l_slices = ps2slices(ps, N * S, L)

    # 2. compress each matrix with its slice of the random key
    blocks = [np.dot(random_key[:, :2 * l + 1].T, l_slices[l]).ravel() for l in range(L + 1)]
    short = np.concatenate(blocks) if blocks else np.zeros(0)

    if norm:
        magnitude = np.linalg.norm(short)
        if magnitude > 0:
            short = short / magnitude
    return short
    
def ps2slices(ps, NS, L):
    """Unpack a flat power spectrum into an (L+1, NS, NS) stack of symmetric l-slices.

    ``ps`` is consumed in consecutive groups of ``L+1`` values, one group per
    lower-triangle pair ``(n1, n2)`` with ``n2 <= n1`` (n1 outer, n2 inner).
    Off-diagonal values are multiplied by 1/sqrt(2) and written to both
    ``[n1, n2]`` and ``[n2, n1]``, so the slices have the same Euclidean norm as
    the entries of ``ps`` that were used. Trailing entries of ``ps`` are ignored.
    """
    rows, cols = np.tril_indices(NS)  # same (n1, n2) order as the nested loops
    n_pairs = rows.size

    # (pair, l) table of raw values; fancy indexing raises IndexError if ps is too short
    flat_idx = np.arange(n_pairs * (L + 1)).reshape(n_pairs, L + 1)
    table = np.asarray(ps)[flat_idx]

    off_diag_scale = np.where(rows == cols, 1.0, 2 ** -0.5)
    per_l = (table * off_diag_scale[:, None]).T  # shape (L+1, n_pairs)

    out = np.zeros((L + 1, NS, NS))
    out[:, rows, cols] = per_l
    out[:, cols, rows] = per_l
    return out

def gen_random_key(params, rng=None):
    """Generate a 'random key' (projection matrix W) to use in compression.

    Parameters
    ----------
    params : dict
        Must provide ``n_max``, ``l_max`` and ``n_Z``.
    rng : numpy.random.Generator, optional
        Source of randomness. When omitted, numpy's legacy global RNG is used
        (the previous behaviour, controlled by ``np.random.seed``). Pass e.g.
        ``np.random.default_rng(0)`` for a reproducible key.

    Returns
    -------
    ndarray of shape (N*S, min(N*S, 2*l_max+1))
        Entries drawn uniformly from [0.1, 0.9). The column count is capped at
        N*S, so channels with 2*l+1 > N*S reuse all available columns.
    """
    N = round(params["n_max"])
    L = round(params["l_max"])
    S = round(params["n_Z"])
    nrows = N * S
    ncols = min(nrows, 2 * L + 1)
    sample = np.random.sample if rng is None else rng.random
    # NOTE(review): the 0.1 offset presumably keeps entries away from zero - confirm intent.
    return 0.1 + 0.8 * sample(size=(nrows, ncols))
 

# Function to undo the compression

In [46]:
def uncompress_slice(WtP, l, RK):
    """Reconstruct one l-slice P_l from its compressed form W^T P_l and the random key.

    Parameters
    ----------
    WtP : array-like, shape (k, NS)
        Compressed slice, i.e. ``RK[:, :2*l+1].T @ P_l`` (k = number of key
        columns used for this l).
    l : int
        Angular channel; selects the first ``2*l+1`` columns of ``RK`` as W.
    RK : ndarray, shape (NS, ncols)
        The random key used for compression.

    Returns
    -------
    ndarray, shape (NS, NS)
        The reconstructed P_l (all zeros if WtP has rank 0).

    Method: with I the indices of a greedily chosen maximal set of linearly
    independent rows of WtP, T = WtP[I] and C = (WtP @ W)[I, I], the slice is
    rebuilt as P = T^T C^{-1} T, applying C^{-1} through its eigendecomposition.
    This recovers P_l only when rank(W^T P_l) == rank(P_l).

    Raises AssertionError if the eigendecomposition does not reproduce C or if
    W^T P does not match WtP; raises ValueError (from ``min``) if no eigenvalue
    of C is positive.
    """

    def get_IR(A):
        """Greedily select linearly independent rows of A.

        A row is kept when appending it increases the rank of the rows kept so
        far. Returns (list of kept rows, list of their indices in A).
        """
        T = []
        inds = []
        cur_rank = 0
        for i, row in enumerate(A):
            T.append(row)
            rank = np.linalg.matrix_rank(T)
            if rank > cur_rank:
                cur_rank = rank
                inds.append(i)
            else:
                T = T[:-1]

        # The full-rank check below was disabled because matrix_rank's numerical
        # tolerance sometimes makes it fail; the original author's note was that
        # rows dropped this way are not needed for the reconstruction.
        #assert np.linalg.matrix_rank(T) == np.linalg.matrix_rank(A)
        return T, inds


    W = RK[:, :2*l+1]
    # keep r independent rows (equations), where r = rank(WtP) <= 2*l+1
    # (equal to rank(Pl) when the reconstruction can succeed)
    T, inds = get_IR(WtP)
    
    # sometimes WtP is exactly zero (e.g. by symmetry) -> return a zero slice
    if len(inds) == 0:
        assert np.linalg.matrix_rank(WtP) == 0
        NS = W.shape[0]
        return np.zeros((NS,NS))

    
    # C = W^T P W restricted to the selected rows and columns
    C = np.dot(WtP, W)
    C = C[inds,:][:,inds]
    
    # eigendecompose the symmetric C = U^T Lambda U and check the reconstruction
    vals, Ut = np.linalg.eigh(C)
    U = Ut.T
    Lambda = np.diag(vals)
    R = np.dot(U.T, np.dot(Lambda, U))
    assert np.allclose(R, C)
    
    # Solve for original Pl: X = Lambda^{-1/2} U T, so P = X^T X = T^T C^{-1} T
    roots = [x**-0.5 for x in vals]
    # Negative eigenvalues give nan roots (numpy float power). They are replaced
    # by min_pos*1e-15, which effectively drops those eigen-directions.
    # NOTE(review): an eigenvalue of exactly 0 gives an inf root, which passes
    # the `> 0` filter and would propagate inf/nan into P.
    min_pos = min([x for x in roots if x > 0])
    # Gram matrices should be positive semi-definite, so non-positive
    # eigenvalues are treated as numerical noise
    roots = [x if x > 0 else min_pos*1e-15 for x in roots]
    D = np.diag(roots)
    X = np.dot(D, np.dot(U, T))
    P = np.dot(X.T, X)
    # the reconstruction must reproduce the compressed input exactly (to tolerance)
    assert np.allclose(np.dot(W.T, P), WtP)

    return P

# Download the Li-TM dataset
Artrith, Nongnuch, Alexander Urban, and Gerbrand Ceder. "Efficient and accurate machine-learning interpolation of atomic energies in compositions with many species." Physical Review B 96.1 (2017): 014112.
https://journals.aps.org/prb/abstract/10.1103/PhysRevB.96.014112


In [47]:
%%time
# Fetch and unpack the Li-TM reference data once; later cells read the .xsf files from disk.
zip_path = os.path.join("datasets", "Li-TM.zip")
marker_path = os.path.join("datasets", "LiMO2-reference-data-11-species", "xsf-files", "structure10415.xsf")

if not os.path.exists(zip_path):
    print("downloading Li-TM dataset")
    os.makedirs("datasets", exist_ok=True)  # urlretrieve does not create missing directories
    url = "https://journals.aps.org/prb/supplemental/10.1103/PhysRevB.96.014112/LiMO2-reference-data-11-species.zip"
    # Download under a temporary name and rename only on success, so an interrupted
    # download cannot leave a truncated zip that the existence check above would accept.
    tmp_path = zip_path + ".part"
    urllib.request.urlretrieve(url, tmp_path)
    os.replace(tmp_path, zip_path)

if not os.path.exists(marker_path):
    print("unzipping")
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall("datasets/")

downloading Li-TM dataset
unzipping


KeyboardInterrupt: 

# Read in the dataset

In [50]:
# Sample 100 structures. glob order depends on the filesystem, so sort first and
# shuffle with a fixed seed to make the sample identical on every Restart & Run All.
xsf_files = sorted(glob.glob("datasets/LiMO2-reference-data-11-species/xsf-files/*.xsf"))
assert xsf_files, "no .xsf files found - did the download/unzip cell succeed?"
random.Random(0).shuffle(xsf_files)
dataset = [read(f) for f in xsf_files[:100]]

# Compress the power spectrum, then undo the compression
* convert the power spectrum to an (L+1, NS, NS) array of "l-slices", and keep it in this form for convenience
* check the reconstruction for 100 structures
* the reconstruction can fail if \\( \text{rank}(W^TP_l) < \text{rank}(P_l) \\) due to an unfortunate choice of \\( W\\)

In [51]:
params = dataset_to_params(dataset)
qs = params2quippy_str(params)
N = round(params["n_max"])
L = round(params["l_max"])
S = round(params["n_Z"])
RK = gen_random_key(params)

# The descriptor depends only on qs, so build it once instead of once per structure.
desc = descriptors.Descriptor(qs)

for i, struc in enumerate(dataset):
    # compute regular power spectrum
    output = desc.calc(struc)
    # NOTE(review): the last entry of "data" is dropped - presumably not part of the
    # power spectrum itself; the length assert below checks the remaining size.
    ps = output["data"][0][:-1]
    assert len(ps) == round(0.5*N*S*(N*S+1)*(L+1))
    full_slices = ps2slices(ps, N*S, L)
    ps2 = full_slices.flatten()
    # unpacking preserves the norm of ps, so the slices must have unit norm
    assert abs(np.linalg.norm(ps2)-1) < 1e-6

    # Compress: one W^T P_l block per l, using the first 2l+1 key columns
    comp_slices = [np.dot(RK[:,:2*l+1].T, Pl) for l, Pl in enumerate(full_slices)]
    n_short = sum(x.size for x in comp_slices)
    # holds when the key has 2*L+1 columns, i.e. 2*L+1 <= N*S
    assert n_short == N*S*(L+1)**2

    # undo compression, using only compressed power spectrum and 'random key'
    recon_slices = np.array([uncompress_slice(WtP, l, RK) for l, WtP in enumerate(comp_slices)])

    # check the reconstruction
    if not np.allclose(full_slices, recon_slices):
        max_err = np.max(np.abs(full_slices - recon_slices))
        print("RECONSTRUCTION FAILED!!")
        raise Exception(f"reconstruction failed for structure {i} (max abs error {max_err:.3e})")
    print("\r", i, "Success",  end="")
print()  # finish the carriage-return progress line

 99 Success