In [25]:
import os
import glob
import numpy as np

# --- Utility to check atom count ---
def get_atom_count_xyz(file_path):
    """
    Reads the first line of an .xyz file to determine atom count.
    """
    with open(file_path, 'r') as f:
        try:
            return int(f.readline().strip())
        except Exception:
            raise ValueError(f"Could not read atom count from {file_path}")

# --- Main orchestration ---
def main(directory):
    """
    Scans .xyz and .npz files in `directory`, reports and counts those with < 3 atoms.

    For .npz files, uses the 'atom_pos' array to count atoms.
    """
    small_xyz = []
    small_npz = []

    # Check .xyz files
    for xyz_path in glob.glob(os.path.join(directory, '*.xyz')):
        try:
            count = get_atom_count_xyz(xyz_path)
        except ValueError as e:
            print(e)
            continue
        if count < 3:
            small_xyz.append((os.path.basename(xyz_path), count))
        else:
            print(f"{os.path.basename(xyz_path)} has {count} atoms (>=3)")

    # Check .npz files
    for npz_path in glob.glob(os.path.join(directory, '*.npz')):
        try:
            data = np.load(npz_path)
            atom_pos = data.get('atom_pos')
            if atom_pos is None:
                raise ValueError(f"'atom_pos' array missing in {npz_path}")
            count = atom_pos.shape[0]
        except Exception as e:
            print(e)
            continue
        if count < 3:
            small_npz.append((os.path.basename(npz_path), count))
        else:
            print(f"{os.path.basename(npz_path)} has {count} atoms (>=3)")

    # Summary of small files
    total_small = len(small_xyz) + len(small_npz)
    print()
    if total_small:
        print(f"Found {total_small} file(s) with fewer than 3 atoms:")
        if small_xyz:
            print(f" - {len(small_xyz)} .xyz file(s):")
            for fname, cnt in small_xyz:
                print(f"     • {fname}: {cnt} atoms")
        if small_npz:
            print(f" - {len(small_npz)} .npz file(s):")
            for fname, cnt in small_npz:
                print(f"     • {fname}: {cnt} atoms")
    else:
        print("No files with fewer than 3 atoms found.")

if __name__ == '__main__':
    # Replace with your directory path
    main('/scratch/phys/sin/sethih1/data_files/planar_molecules_256_old/train/')

764.npz has 16 atoms (>=3)
33776.npz has 19 atoms (>=3)
19.npz has 17 atoms (>=3)
78847.npz has 15 atoms (>=3)
19914.npz has 15 atoms (>=3)
10435.npz has 13 atoms (>=3)
15875.npz has 16 atoms (>=3)
6846.npz has 19 atoms (>=3)
1491.npz has 17 atoms (>=3)
9222.npz has 15 atoms (>=3)
456.npz has 13 atoms (>=3)
7971.npz has 12 atoms (>=3)
95303.npz has 18 atoms (>=3)
15863.npz has 17 atoms (>=3)
27492.npz has 11 atoms (>=3)
92970.npz has 19 atoms (>=3)
68401.npz has 11 atoms (>=3)
67126.npz has 10 atoms (>=3)
10366.npz has 17 atoms (>=3)
151097.npz has 18 atoms (>=3)
68442.npz has 15 atoms (>=3)
4649.npz has 18 atoms (>=3)
214615.npz has 18 atoms (>=3)
237859.npz has 18 atoms (>=3)
9338.npz has 17 atoms (>=3)
96389.npz has 18 atoms (>=3)
137009.npz has 10 atoms (>=3)
123419.npz has 11 atoms (>=3)
126347.npz has 14 atoms (>=3)
72914.npz has 18 atoms (>=3)
202937.npz has 19 atoms (>=3)
141184.npz has 18 atoms (>=3)
77987.npz has 14 atoms (>=3)
88025.npz has 16 atoms (>=3)
1456.npz has 19 ato

In [14]:
npz_file = '/scratch/phys/sin/sethih1/data_files/planar_molecules_256_old/train/79510.npz' 
data = np.load(npz_file) 


In [21]:
data.files

['atom_pos', 'atomic_numbers', 'x_pos', 'y_pos', 'frequencies', 'spectrums']