## This notebook contains code for generating parity plots that compare the SP method results with those obtained from the rigid pairwise virial formulation.

In [None]:
import os
import gzip
import glob

def read_dump_file(filename):
    """
    Read the first snapshot of a gzip-compressed, text-format LAMMPS dump file.

    Only the first "ITEM: TIMESTEP" record is parsed; anything after its
    atom lines is ignored.

    Parameters:
        filename: path to a dump.*.gz file (opened with gzip in text mode).

    Returns: timestep, natoms, box_style_line, box_bounds, atom_prop_line, properties, atoms
        timestep        int read from the line after "ITEM: TIMESTEP"
        natoms          int read from the line after "ITEM: NUMBER OF ATOMS"
        box_style_line  the full "ITEM: BOX BOUNDS ..." header line (stripped)
        box_bounds      three [lo, hi] float pairs; any further columns on
                        those lines are ignored
        atom_prop_line  the full "ITEM: ATOMS ..." header line (stripped)
        properties      column names listed after "ITEM: ATOMS"
        atoms           list of natoms dicts mapping each property name to its
                        value ('id' and 'type' as int, everything else float)

    Raises:
        ValueError: if a section header is missing or out of order, if a
            number cannot be parsed, or if the file ends / a line is too
            short before all expected values were read.
    """
    with gzip.open(filename, 'rt') as f:
        # Skip until TIMESTEP.  EOF is detected on the *raw* line (readline()
        # returns '' only at end of file) so that blank lines ahead of the
        # header do not terminate the search prematurely.
        while True:
            raw = f.readline()
            if not raw:
                raise ValueError("Expected 'ITEM: TIMESTEP' line.")
            if "ITEM: TIMESTEP" in raw:
                break
        # Next line is timestep
        timestep = int(f.readline().strip())

        # Next: ITEM: NUMBER OF ATOMS
        line = f.readline().strip()
        if "ITEM: NUMBER OF ATOMS" not in line:
            raise ValueError("Expected 'ITEM: NUMBER OF ATOMS' line.")
        natoms = int(f.readline().strip())

        # Next: ITEM: BOX BOUNDS ...
        box_style_line = f.readline().strip()
        if "ITEM: BOX BOUNDS" not in box_style_line:
            raise ValueError("Expected 'ITEM: BOX BOUNDS' line.")
        box_bounds = []
        for _ in range(3):
            bounds_line = f.readline().split()
            if len(bounds_line) < 2:
                raise ValueError(
                    f"Expected 'lo hi' box bounds in {filename}, got {bounds_line!r}."
                )
            box_bounds.append([float(bounds_line[0]), float(bounds_line[1])])

        # Next: ITEM: ATOMS ...
        atom_prop_line = f.readline().strip()
        if "ITEM: ATOMS" not in atom_prop_line:
            raise ValueError("Expected 'ITEM: ATOMS' line.")

        # parts[0]="ITEM:", parts[1]="ATOMS", rest=properties
        properties = atom_prop_line.split()[2:]

        # Read atoms.  A short or missing line means the file is truncated;
        # fail loudly instead of returning atoms with missing properties.
        int_columns = ('id', 'type')
        atoms = []
        for index in range(natoms):
            data = f.readline().split()
            if len(data) < len(properties):
                raise ValueError(
                    f"Atom line {index + 1} of {natoms} in {filename} has "
                    f"{len(data)} values; expected {len(properties)}."
                )
            atoms.append({
                p: int(v) if p in int_columns else float(v)
                for p, v in zip(properties, data)
            })

    return timestep, natoms, box_style_line, box_bounds, atom_prop_line, properties, atoms

def write_dump_file(filename, timestep, natoms, box_style_line, box_bounds, properties, atoms):
    """
    Write a single-snapshot, gzip-compressed LAMMPS text dump file.

    Parameters:
        filename: output path; written with gzip in text mode (overwritten
            if it already exists).
        timestep: value written after "ITEM: TIMESTEP".
        natoms: value written after "ITEM: NUMBER OF ATOMS" (written as
            given; it is not checked against len(atoms)).
        box_style_line: complete "ITEM: BOX BOUNDS ..." header line.
        box_bounds: iterable of (lo, hi) pairs, one output line each.
        properties: column names; they form the "ITEM: ATOMS" header and
            select/order the values written for every atom.  Should already
            include any added err_ij columns.
        atoms: iterable of dicts holding a value for every name in properties.

    Raises:
        KeyError: if an atom dict lacks one of the properties.
    """
    def format_atom(atom):
        # 'id' and 'type' are integer columns.  Everything else is written
        # with repr(float), the shortest text that parses back to the same
        # float, so no precision is lost on a write/read round trip.
        return " ".join(
            str(int(atom[p])) if p in ('id', 'type') else repr(float(atom[p]))
            for p in properties
        )

    with gzip.open(filename, 'wt') as f:
        f.write("ITEM: TIMESTEP\n")
        f.write(f"{timestep}\n")
        f.write("ITEM: NUMBER OF ATOMS\n")
        f.write(f"{natoms}\n")
        f.write(box_style_line + "\n")
        f.writelines(f"{lohi[0]} {lohi[1]}\n" for lohi in box_bounds)

        # Rebuild the ITEM: ATOMS line from properties
        f.write("ITEM: ATOMS " + " ".join(properties) + "\n")
        f.writelines(format_atom(atom) + "\n" for atom in atoms)

# Virial tensor components handled below, as
# (suffix of file 1's diff_virial_*/err_* columns,
#  matching per-atom column in file 2,
#  Cartesian label used in the plot axis titles).
_VIRIAL_COMPONENTS = (
    ('11', 'c_4[1]', 'xx'),
    ('22', 'c_4[2]', 'yy'),
    ('33', 'c_4[3]', 'zz'),
    ('12', 'c_4[4]', 'xy'),
    ('13', 'c_4[5]', 'xz'),
    ('23', 'c_4[6]', 'yz'),
)


def _timestep_from_filename(filename):
    """Return the integer between the first two dots of the basename (200 for "dump.200.gz")."""
    return int(os.path.basename(filename).split('.')[1])


def _save_parity_plot(atom_pairs, scale_factor, filename):
    """Save a 2x3 grid of parity plots (file-1 value vs. scaled file-2 value) to filename."""
    # Imported here so that pairing/validating files does not require matplotlib.
    import matplotlib.pyplot as plt

    fig, axes = plt.subplots(2, 3, figsize=(12, 6))
    axes = axes.flatten()

    for ax, (comp, c4_key, comp_name) in zip(axes, _VIRIAL_COMPONENTS):
        x_data = [a2[c4_key] * scale_factor for _, a2 in atom_pairs]
        y_data = [a1['diff_virial_' + comp] for a1, _ in atom_pairs]

        # Perfect fit line (y = x) spanning the data range.  Skipped when
        # there are no atoms, since min()/max() of an empty list would raise.
        if x_data:
            min_val = min(min(x_data), min(y_data))
            max_val = max(max(x_data), max(y_data))
            ax.plot([min_val, max_val], [min_val, max_val], 'k--', label='(y = x)')

        ax.scatter(x_data, y_data, s=18, color='#f13e36', alpha=0.2,
                   edgecolor='black', linewidth=0.3, label='Data points')

        ax.set_xlabel(f"Pairwise: v_{comp_name} (eV)", fontsize=16)
        ax.set_ylabel(f"SP: v_{comp_name} (eV)", fontsize=16)

    # Add legend to the first subplot only (to avoid clutter)
    axes[0].legend()

    fig.tight_layout()
    fig.savefig(filename, dpi=300)
    # Close this specific figure so figures do not accumulate across files.
    plt.close(fig)


def process_all_pairs(folder1, folder2, outfolder, *, scale_factor=6.242e-7, write_dump=False):
    """
    Compare per-atom virials from two sets of LAMMPS dump files and plot them.

    Files matching "dump.*.cfg.gz" in folder1 and "dump.*.gz" in folder2 are
    each sorted by the number in their name and paired positionally (first
    with first, second with second, ...).  For every pair the timestep, atom
    count and box bounds must agree.  Atoms are matched by 'id', the columns
    err_ij = diff_virial_ij - c_4[k] * scale_factor are added to the atoms of
    the folder1 file, and a parity plot "diff_<timestep>.png" is saved in
    outfolder.

    Parameters:
        folder1: directory with dumps providing diff_virial_11 ... diff_virial_23
            (plotted on the y axis, labelled "SP").
        folder2: directory with dumps providing c_4[1] ... c_4[6]
            (plotted on the x axis, labelled "Pairwise").
        outfolder: output directory; created if missing.
        scale_factor: multiplier applied to the c_4[k] values before they are
            compared and plotted.  NOTE(review): the default 6.242e-7 looks
            like a bar*Angstrom^3 -> eV conversion (the axes are labelled
            eV) -- confirm against the simulation's units.
        write_dump: if True, also write the folder1 atoms including the new
            err_ij columns to "dump.err.<timestep>.gz" in outfolder.

    Raises:
        ValueError: if the folders hold different numbers of files, if a pair
            disagrees in timestep, atom count, box bounds or atom ids, or if
            a required column is missing.
    """
    files1 = sorted(glob.glob(os.path.join(folder1, "dump.*.cfg.gz")),
                    key=_timestep_from_filename)
    files2 = sorted(glob.glob(os.path.join(folder2, "dump.*.gz")),
                    key=_timestep_from_filename)
    print(files1, files2)

    # Ensure same number of files
    if len(files1) != len(files2):
        raise ValueError("Mismatch in number of files between folder1 and folder2.")
    os.makedirs(outfolder, exist_ok=True)

    for f1, f2 in zip(files1, files2):
        t1, natoms1, box_style1, box_bounds1, _, props1, atoms1 = read_dump_file(f1)
        t2, natoms2, _, box_bounds2, _, props2, atoms2 = read_dump_file(f2)

        if t1 != t2:
            raise ValueError(f"Timestep mismatch between {f1} and {f2}: {t1} vs {t2}")
        if natoms1 != natoms2:
            raise ValueError(f"Atom count mismatch: {f1} has {natoms1}, {f2} has {natoms2}")
        if box_bounds1 != box_bounds2:
            raise ValueError(f"Box bounds differ between {f1} and {f2}")

        # Sort both files' atoms by id so they can be compared pairwise
        atom_pairs = list(zip(sorted(atoms1, key=lambda x: x['id']),
                              sorted(atoms2, key=lambda x: x['id'])))

        for comp, c4_key, _ in _VIRIAL_COMPONENTS:
            if 'diff_virial_' + comp not in props1:
                raise ValueError(f"Property {'diff_virial_' + comp} not found in {f1}")
        for _, c4_key, _ in _VIRIAL_COMPONENTS:
            if c4_key not in props2:
                raise ValueError(f"Property {c4_key} not found in {f2}")

        # Compute errors with the same scale factor the plot uses, so that
        # err_ij is exactly (plotted y) - (plotted x).
        for a1, a2 in atom_pairs:
            if a1['id'] != a2['id']:
                raise ValueError(f"Atom IDs do not match after sorting between {f1} and {f2}.")
            for comp, c4_key, _ in _VIRIAL_COMPONENTS:
                a1['err_' + comp] = a1['diff_virial_' + comp] - a2[c4_key] * scale_factor
        props1.extend('err_' + comp for comp, _, _ in _VIRIAL_COMPONENTS)

        plot_file = os.path.join(outfolder, f"diff_{t1}.png")
        _save_parity_plot(atom_pairs, scale_factor, plot_file)
        written = [plot_file]

        if write_dump:
            # The atom dicts in atoms1 are the same objects that received the
            # err_ij values above; they are written in their original order.
            outfile = os.path.join(outfolder, f"dump.err.{t1}.gz")
            write_dump_file(outfile, t1, natoms1, box_style1, box_bounds1, props1, atoms1)
            written.append(outfile)

        # Report only the files that were actually produced.
        print(f"Processed {f1} & {f2} -> {', '.join(written)}")


In [None]:
# Locations of the inputs and outputs for the Cu_defect_Morse case.
# folder1 is searched for dump.*.cfg.gz files (diff_virial_* columns) and
# folder2 for dump.*.gz files (c_4[*] columns); plots go to outfolder.
folder = './SP_stress/'
folder1 = f'{folder}Cu_defect_Morse/dumpVirial_delta1e-2'
folder2 = f'{folder}Cu_defect_Morse/dump_original'
outfolder = f'{folder}Cu_defect_Morse/dump_Cu_defect_Morse_err_delta1e-2/plots'

# Pair the dumps from both folders and write one parity plot per timestep.
process_all_pairs(folder1=folder1, folder2=folder2, outfolder=outfolder)