# 0. Calculate the `davgs` and `dstds` using the first 10 frames.
1. `davgs.shape = (num_types, 4)`
2. `dstds.shape = (num_types, 4)`

In [1]:
import numpy as np
from matersdk.io.pwmat.output.movement import Movement
from matersdk.data.deepmd.data_system import DpLabeledSystem

from matersdk.feature.deepmd.preprocess import TildeRNormalizer

# 1. Use the first 10 frames of `DpLabeledSystem` to calculate statistic data, and init `TildeRNormalizer`

## Step 1. Initialize the `DpLabeledSystem`

In [2]:
movement_path = "/data/home/liuhanyu/hyliu/code/mlff/test/demo2/PWdata/data1/MOVEMENT"
movement = Movement(movement_path=movement_path)

dpsys = DpLabeledSystem.from_trajectory_s(trajectory_object=movement)
print(dpsys)

****************** LabeledSystem Summary *******************
	 * Images Number           : 550           
	 * Atoms Number            : 72            
	 * Virials Information     : True          
	 * Energy Deposition       : True          
	 * Elements List           :
		 - Li: 48              
		 - Si: 24              
************************************************************



## Step 1.2. Calculate the `davgs` and `dstds`, init `TildeRNormalizer`

In [3]:
structure_indices = [*range(10)]    # PWmat-MLFF 取前10帧结构计算`davg`和`dstd`
rcut = 5
rcut_smooth = 0.5
center_atomic_numbers = [3, 14]
nbr_atomic_numbers = [3, 14]
max_num_nbrs = [100, 100]
scaling_matrix = [3, 3, 3]

In [4]:
tilde_r_normalizer = TildeRNormalizer.from_dp_labeled_system(
                dp_labeled_system=dpsys,
                structure_indices=structure_indices,
                rcut=rcut,
                rcut_smooth=rcut_smooth,
                center_atomic_numbers=center_atomic_numbers,
                nbr_atomic_numbers=nbr_atomic_numbers,
                max_num_nbrs=max_num_nbrs,
                scaling_matrix=scaling_matrix
)

davgs, dstds = tilde_r_normalizer.davgs, tilde_r_normalizer.dstds

In [5]:
print("\nStep 1. davgs = ")
print(davgs)
print("\nStep 2. dstds = ")
print(dstds)


Step 1. davgs = 
[[0.0099991  0.         0.         0.        ]
 [0.01075823 0.         0.         0.        ]]

Step 2. dstds = 
[[0.03942547 0.02348297 0.02348297 0.02348297]
 [0.04493047 0.02667387 0.02667387 0.02667387]]


## Step 1.3. Normalize $\tilde{R}$ of new `DStructure` using `davgs` nad `dstds`

The `343st frame` is corresponding to `image_005`

In [8]:
new_structure = movement.get_frame_structure(idx_frame=343)
tildeR_dict, tildeR_derivative_dict = tilde_r_normalizer._normalize(structure=new_structure)

In [9]:
print("Step 1. The Rij:")
for tmp_pair, tmp_normed_tildeR in tildeR_dict.items():
    print('\t', tmp_pair, ": ", tmp_normed_tildeR.shape)
    
print("Step 2. The derivative of Rij with respect to x, y, z:")
for tmp_key, tmp_value in tildeR_derivative_dict.items():
    print("\t", tmp_key, ": ", tmp_value.shape)

Step 1. The Rij:
	 3_3 :  (48, 100, 4)
	 3_14 :  (48, 100, 4)
	 14_3 :  (24, 100, 4)
	 14_14 :  (24, 100, 4)
Step 2. The derivative of Rij with respect to x, y, z:
	 3_3 :  (48, 100, 4, 3)
	 3_14 :  (48, 100, 4, 3)
	 14_3 :  (24, 100, 4, 3)
	 14_14 :  (24, 100, 4, 3)


## Step 1.4. Compare with Fortran code

In [10]:
tmp_a = np.concatenate(
        [
            tildeR_derivative_dict["3_3"], 
            tildeR_derivative_dict["3_14"]], axis=1)
print(tmp_a.shape)
print(np.min(tmp_a))
print(np.max(tmp_a))
print(np.sum(tmp_a))

(48, 200, 4, 3)
-6.0923357024937
10.305643893427103
2619.4930606840157


In [12]:
idx_lst = []

for tmp_idx in range(550):
    new_structure = movement.get_frame_structure(idx_frame=tmp_idx)
    tildeR_dict, tildeR_derivative_dict = tilde_r_normalizer._normalize(structure=new_structure)

    tmp_a = np.concatenate(
            [
                tildeR_derivative_dict["3_3"], 
                tildeR_derivative_dict["3_14"]], axis=1)

    result_sum = np.sum(tmp_a)
    
    print(tmp_idx, result_sum)
    
    if (result_sum - 2619.4930606962866) < 1e-3:
        idx_lst.append(tmp_idx)

0 2610.3711597780357
1 2612.129263287076
2 2613.3541825450648
3 2614.0455055556004
4 2614.218048728575
5 2613.931883044359
6 2613.3184876903906
7 2612.5073042034146
8 2611.6417446515143
9 2610.847397439408
10 2610.2141071546203
11 2625.5480993685896
12 2627.608093736136
13 2629.06968679895
14 2629.955440914954
15 2630.30723343026
16 2630.1781890214993
17 2629.659891377474
18 2628.919940265754
19 2628.0980596872982
20 2627.2952334228676
21 2626.587734351678
22 2664.041567504003
23 2665.6825643678626
24 2666.708770503542
25 2667.2263218101775
26 2667.294538697651
27 2667.023659949333
28 2666.544284292784
29 2665.9858983839085
30 2665.4376967302496
31 2664.9642336769457
32 2664.619241009441
33 2651.1101516592366
34 2652.738186370075
35 2653.7785444697715
36 2654.2621528433833
37 2654.2655031565864
38 2653.9103466696306
39 2653.2947031222966
40 2652.510662552379
41 2651.6815684901535
42 2650.9098307733216
43 2650.249959041333
44 2681.8925631559778
45 2684.339256473476
46 2686.362109164873


In [60]:
print(idx_lst)

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 109, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 297, 308, 309, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 341, 342, 343, 347, 348, 349, 350, 351, 396, 397, 398, 399, 400, 401, 402, 403, 404, 405, 406, 473]


# 2. Save the `TildeRNormalizer` to `hdf5 file`

In [8]:
hdf5_file_path = "./demo_normalizer.h5"

tilde_r_normalizer.to(hdf5_file_path=hdf5_file_path)

In [9]:
ll ./demo_normalizer.h5

-rw-rw-r-- 1 liuhanyu 5632 Jun 25 16:48 ./demo_normalizer.h5


# 3. Init the `TildeRNormalizer` from `hdf5` file

In [10]:
new_trn = TildeRNormalizer.from_file(hdf5_file_path=hdf5_file_path)
print(new_trn)

*************************** TildeRNormalizer Summary ***************************
	 * rcut                      :       6.500000
	 * rcut_smooth               :       6.000000
	 * center_atomic_numbers:    :	 [ 3 14]
	 * nbr_atomic_numbers:       :	 [ 3 14]
	 * max_num_nbrs              :	 [100  80]
	 * scaling_matrix            :	 [3 3 3]
	 * davgs                     :	
[[0.06974313 0.         0.         0.        ]
 [0.06922328 0.         0.         0.        ]]
	 * dstds                     :	
[[0.11278804 0.07656205 0.07656205 0.07656205]
 [0.1140824  0.07704253 0.07704253 0.07704253]]
********************************************************************************

