In [29]:
from glob import glob
from polyphys.manage import parser
from polyphys.manage.utils import sort_filenames
import ipytest
import pytest
ipytest.autoconfig()


In [2]:
a = ('a', 1, 'c', [1, 2])
a[-1].append(3)
a

('a', 1, 'c', [1, 2, 3])

In [9]:
paragraph = """ArithmeticError
ArithmeticError is the base class for all errors that occur for numeric calculations. It is a subclass of Exception and is used to indicate errors that occur during arithmetic operations, such as division by zero or overflow errors.
It is not raised directly, but rather by its subclasses, such as ZeroDivisionError or OverflowError. These subclasses provide more specific information about the type of arithmetic error that occurred.
"""
paragraph = [word for sentence in paragraph for word in sentence.split()]
paragraph

['A',
 'r',
 'i',
 't',
 'h',
 'm',
 'e',
 't',
 'i',
 'c',
 'E',
 'r',
 'r',
 'o',
 'r',
 'A',
 'r',
 'i',
 't',
 'h',
 'm',
 'e',
 't',
 'i',
 'c',
 'E',
 'r',
 'r',
 'o',
 'r',
 'i',
 's',
 't',
 'h',
 'e',
 'b',
 'a',
 's',
 'e',
 'c',
 'l',
 'a',
 's',
 's',
 'f',
 'o',
 'r',
 'a',
 'l',
 'l',
 'e',
 'r',
 'r',
 'o',
 'r',
 's',
 't',
 'h',
 'a',
 't',
 'o',
 'c',
 'c',
 'u',
 'r',
 'f',
 'o',
 'r',
 'n',
 'u',
 'm',
 'e',
 'r',
 'i',
 'c',
 'c',
 'a',
 'l',
 'c',
 'u',
 'l',
 'a',
 't',
 'i',
 'o',
 'n',
 's',
 '.',
 'I',
 't',
 'i',
 's',
 'a',
 's',
 'u',
 'b',
 'c',
 'l',
 'a',
 's',
 's',
 'o',
 'f',
 'E',
 'x',
 'c',
 'e',
 'p',
 't',
 'i',
 'o',
 'n',
 'a',
 'n',
 'd',
 'i',
 's',
 'u',
 's',
 'e',
 'd',
 't',
 'o',
 'i',
 'n',
 'd',
 'i',
 'c',
 'a',
 't',
 'e',
 'e',
 'r',
 'r',
 'o',
 'r',
 's',
 't',
 'h',
 'a',
 't',
 'o',
 'c',
 'c',
 'u',
 'r',
 'd',
 'u',
 'r',
 'i',
 'n',
 'g',
 'a',
 'r',
 'i',
 't',
 'h',
 'm',
 'e',
 't',
 'i',
 'c',
 'o',
 'p',
 'e',
 'r',
 'a'

In [None]:
import random
from random import choice

In [5]:
shuffled = random.shuffle([1, 2, 3, 4, 5])

In [None]:
"""\
generate_test_data
==================

Utility functions for generating and preparing LAMMPS test data files.

This module contains tools to transform large raw simulation data files
(e.g., LAMMPS trajectory files) into smaller, compressed test data files
used internally in the PolyPhys test suite. The original source data files
are not part of the repository, but the functions here document and 
automate the generation process for transparency and reproducibility.

Functions
---------
- read_n_atoms(trj_path)
    Read number of atoms from line 4 of a LAMMPS trajectory file.

- split_trj_gz(trj_path, n_frames, n_files, output_prefix)
    Extract the last N frames from a LAMMPS trajectory file and write them
    as compressed `.trj.gz` test files, optionally overlapping frames between
    files.

Notes
-----
This module is intended for development use only and should not be imported
as part of the core PolyPhys package.
"""


import gzip
import logging
from typing import Optional, Union, Tuple
from pathlib import Path

logger = logging.getLogger(__name__)


def find_frame_pairs(
    a_n_wholes, a_n_segments,
    b_n_wholes, b_n_segments,
    a_dump_freq=5000, b_dump_freq=2000,
    a_continuous=True, b_continuous=True,
    max_a_frames=10000, max_b_frames=10000,
    a_fps_min=50
):
    """
    Find valid frame counts per segment for a-type and b-type trajectory files
    such that the total number of timesteps from initial to final match between
    the two types, accounting for dumping frequency and segment continuity.

    Parameters
    ----------
    a_n_wholes : int
        Number of whole a-type trajectory sets.
    a_n_segments : int
        Number of a-type segments per whole.
    b_n_wholes : int
        Number of whole b-type trajectory sets.
    b_n_segments : int
        Number of b-type segments per whole.
    a_dump_freq : int, optional
        Timestep interval between a-type frames. Default is 5000.
    b_dump_freq : int, optional
        Timestep interval between b-type frames. Default is 2000.
    a_continuous : bool, optional
        Whether a-type segments overlap by one frame (continuous). Default is True.
    b_continuous : bool, optional
        Whether b-type segments overlap by one frame (continuous). Default is True.
    max_a_frames : int, optional
        Maximum allowed total frame count for a-type. Default is 10000.
    max_b_frames : int, optional
        Maximum allowed total frame count for b-type. Default is 10000.
    a_fps_min : int, optional
        Minimum value to try for `a_frames_per_segment`. Default is 50.

    Returns
    -------
    list of tuple of int
        A list of tuples `(a_frames_per_segment, b_frames_per_segment)` that satisfy
        timestep equivalence and frame count constraints.

    Notes
    -----
    This function ensures that the **total number of simulation timesteps** (not frames)
    between the first and last frame is the same for both a-type and b-type outputs.

    Let:
        - `N_a = a_n_wholes * a_n_segments`
        - `N_b = b_n_wholes * b_n_segments`
        - `A` = a_frames_per_segment
        - `B` = b_frames_per_segment

    Then:
        total_a_frames = N_a * A - int(a_continuous) * (N_a - 1) - 1
        total_a_timesteps = total_a_frames * a_dump_freq

    To match timesteps, solve:
        total_b_timesteps = total_a_timesteps
        → total_b_frames = total_a_timesteps // b_dump_freq + 1

    Then compute:
        B = (total_b_frames + int(b_continuous) * (N_b - 1)) / N_b

    The function only returns pairs where `B` is a positive integer and both
    `total_a_frames` and `total_b_frames` stay within user-specified limits.
    """
    results = []
    N_a = a_n_wholes * a_n_segments
    N_b = b_n_wholes * b_n_segments
    a_adj = int(a_continuous) * (N_a -1)
    b_adj = int(b_continuous) * (N_b -1)

    for a_fps in range(a_fps_min, max_a_frames + 1):
        total_a_frames = N_a * a_fps - a_adj - 1 # 1 for inclusive counting
        if total_a_frames > max_a_frames:
            continue

        total_a_timesteps = total_a_frames * a_dump_freq

        # Solve for total_b_frames
        if total_a_timesteps % b_dump_freq != 0:
            continue

        total_b_frames = total_a_timesteps // b_dump_freq + 1 # 1 for inclusive counting
        if total_b_frames > max_b_frames:
            continue

        # Solve for b_fps
        numerator = total_b_frames + b_adj
        if numerator % N_b != 0:
            continue

        b_fps = numerator // N_b
        if b_fps < 1:
            continue

        results.append((a_fps, b_fps))

    return results


def get_first_last_timesteps(file_path: str) -> Tuple[int, int]:
    """
    Return the first and last timestep from a LAMMPS trajectory (dump) file.

    Supports both plain text and gzip-compressed files.

    Parameters
    ----------
    file_path : str
        Path to the LAMMPS dump file (.lammpstrj or .gz).

    Returns
    -------
    tuple of int
        A tuple containing:
        - The timestep of the first frame.
        - The timestep of the last frame.

    Raises
    ------
    ValueError
        If the file does not contain any recognizable `ITEM: TIMESTEP` entries.

    Notes
    -----
    This function reads the first frame sequentially and the last frame by scanning
    from the end of the file, which is efficient even for large files.
    For `.gz` files, the entire file must be scanned (not seekable), which is slower.
    """
    is_gz = file_path.endswith(".gz")
    open_func = gzip.open if is_gz else open

    # Get first timestep
    with open_func(file_path, "rt") as f:
        for line in f:
            if line.startswith("ITEM: TIMESTEP"):
                first_ts = int(next(f).strip())
                break

    if not is_gz:
        # Efficient seek-based tail scan for uncompressed files
        with open(file_path, "r") as f:
            f.seek(0, 2)
            file_size = f.tell()
            block_size = 8192
            buffer = ''
            pos = file_size

            while pos > 0:
                read_size = min(block_size, pos)
                pos -= read_size
                f.seek(pos)
                buffer = f.read(read_size) + buffer
                lines = buffer.splitlines()

                for i in range(len(lines) - 1, -1, -1):
                    if lines[i].startswith("ITEM: TIMESTEP"):
                        last_ts = int(lines[i + 1].strip())
                        return first_ts, last_ts
    else:
        # Fallback: linear scan for gzip files
        with gzip.open(file_path, "rt") as f:
            last_ts = None
            for line in f:
                if line.startswith("ITEM: TIMESTEP"):
                    last_ts = int(next(f).strip())
            if last_ts is not None:
                return first_ts, last_ts

    raise ValueError("Could not find timestep entries in the file.")


def read_n_atoms_from_lammps_trj(trj_path: Union[str, Path]) -> int:
    """Read number of atoms from line 4 of a LAMMPS trajectory file."""
    trj_path = Path(trj_path)
    opener = gzip.open if trj_path.suffix == ".gz" else open
    with opener(trj_path, "rt") as f:
        for i, line in enumerate(f):
            if i == 3:
                return int(line.strip())
    raise ValueError("File too short to contain number of atoms at line 4.")


def split_lammps_trj(
    filename: Union[str, Path],
    n_wholes: int = 1,
    n_segments: int = 2,
    frames_per_segment: int = 50,
    prefix: Optional[str] = None,
    suffix: Optional[str] = None,
    continuous: bool = True,
    compress_output: bool = False,
    save_to: Optional[Union[str, Path]] = None
) -> None:
    """
    Split a LAMMPS trajectory file into smaller test segments, with optional
    frame overlap and compression.

    Parameters
    ----------
    filename : str or Path
        Path to the input LAMMPS trajectory file (.lammpstrj or .gz).
    n_wholes : int, optional
        Number of ensemble wholes to generate. Default is 1.
    n_segments : int, optional
        Number of segments per whole. Default is 2.
    frames_per_segment : int, optional
        Number of frames in each segment. Default is 50.
    prefix : str, optional
        Optional prefix for output filenames.
    suffix : str, optional
        Optional suffix for output filenames.
    continuous : bool, optional
        If True, segments will overlap by one frame. Default is True.
    compress_output : bool, optional
        If True, output files will be gzip-compressed. Default is False.
    save_to : str or Path, optional
        Directory to save output files. If None, uses input file's directory.

    Raises
    ------
    ValueError
        If the input file contains fewer lines than needed for splitting.

    Notes
    -----
    For `continuous=True`, segments overlap by one frame (e.g., frame N of one
    segment is reused as frame 0 of the next). For `continuous=False`, segments
    are disjoint.
    """
    trj_path = Path(filename)
    n_atoms = read_n_atoms_from_lammps_trj(trj_path)
    lines_per_frame = n_atoms + 9
    n_files = n_wholes * n_segments

    n_frames = n_files * frames_per_segment - int(continuous) * (n_files - 1)
    total_lines = n_frames * lines_per_frame

    opener = gzip.open if trj_path.suffix == ".gz" else open
    with opener(trj_path, "rt") as f:
        all_lines = [next(f) for _ in range(total_lines)]

    if len(all_lines) < total_lines:
        raise ValueError(f"Input file has insufficient frames to split: "
                         f"required {n_frames}, found {len(all_lines) // lines_per_frame}")

    frames = [
        all_lines[i * lines_per_frame: (i + 1) * lines_per_frame]
        for i in range(n_frames)
    ]

    out_ext = ".lammpstrj.gz" if compress_output else ".lammpstrj"

    if prefix and not prefix.endswith("."):
        prefix += "."

    output_dir = Path(save_to) if save_to is not None else trj_path.parent
    output_dir.mkdir(parents=True, exist_ok=True)

    for i in range(n_wholes):
        for j in range(n_segments):
            file_index = i * n_segments + j
            start = (
                file_index * (frames_per_segment - 1)
                if continuous
                else file_index * frames_per_segment
            )
            end = start + frames_per_segment
            chunk = frames[start:end]

            if n_segments == 1:
                out_name = f"{prefix or ''}ens{i+1}.{suffix or 'part'}{out_ext}"
            else:
                out_name = f"{prefix or ''}ens{i+1}.j{j+1}.{suffix or 'part'}{out_ext}"

            out_path = output_dir / out_name
            out_opener = gzip.open if compress_output else open

            with out_opener(out_path, "wt") as f:
                for frame in chunk:
                    f.writelines(frame)

            logger.info(f"Wrote {len(chunk)} frames to {out_path}")



In [55]:
a_n_wholes = 1
a_n_segments = 4
b_n_wholes = 1
b_n_segments = 1
a_continuous = True
b_continuous = False
a_dump_freq = 5000
b_dump_freq = 2000

res = find_frame_pairs(
    a_n_wholes,
    a_n_segments,
    b_n_wholes,
    b_n_segments,
    a_dump_freq=a_dump_freq,
    b_dump_freq=b_dump_freq,
    a_continuous=True,
    b_continuous=False,
    max_a_frames=1001,
    max_b_frames=2501,
    a_fps_min=50
)
print(res)
#a_fps, b_fps = res[0]
#print(f"A = {a_fps}, B = {b_fps}")

[(50, 491), (51, 501), (52, 511), (53, 521), (54, 531), (55, 541), (56, 551), (57, 561), (58, 571), (59, 581), (60, 591), (61, 601), (62, 611), (63, 621), (64, 631), (65, 641), (66, 651), (67, 661), (68, 671), (69, 681), (70, 691), (71, 701), (72, 711), (73, 721), (74, 731), (75, 741), (76, 751), (77, 761), (78, 771), (79, 781), (80, 791), (81, 801), (82, 811), (83, 821), (84, 831), (85, 841), (86, 851), (87, 861), (88, 871), (89, 881), (90, 891), (91, 901), (92, 911), (93, 921), (94, 931), (95, 941), (96, 951), (97, 961), (98, 971), (99, 981), (100, 991), (101, 1001), (102, 1011), (103, 1021), (104, 1031), (105, 1041), (106, 1051), (107, 1061), (108, 1071), (109, 1081), (110, 1091), (111, 1101), (112, 1111), (113, 1121), (114, 1131), (115, 1141), (116, 1151), (117, 1161), (118, 1171), (119, 1181), (120, 1191), (121, 1201), (122, 1211), (123, 1221), (124, 1231), (125, 1241), (126, 1251), (127, 1261), (128, 1271), (129, 1281), (130, 1291), (131, 1301), (132, 1311), (133, 1321), (134, 13

In [56]:

trj_pairs = glob("/Users/amirhsi/research_data/TransFociCyl/all_simulations/D20nl5ns400al6ac1-all_simulations/epss*/epss*lammpstrj")
trj_pairs = sort_filenames(
    trj_pairs,
    formats=['j20.ring.all.lammpstrj', '.ring.bug.lammpstrj']
)

@pytest.mark.parametrize("all_path,bug_path", trj_pairs)

def test_matching_timesteps_across_adump_bdump(
    all_path, bug_path
):
    # Assume files are in place
    assert all_path and bug_path, "Missing input test files"

    a_n_wholes = 1
    a_n_segments = 4
    b_n_wholes = 1
    b_n_segments = 1
    a_continuous = True
    b_continuous = False
    max_a_frames = 1001
    max_b_frames = 2501  

    save_to = f"split-{a_n_wholes}_{a_n_segments}_{str(a_continuous)}_{b_n_wholes}_{b_n_segments}_{str(b_continuous)}"
    a_sim = parser.TransFociCyl(all_path, "segment", "all", ispath=True)
    b_sim = parser.TransFociCyl(bug_path, "whole", "bug", ispath=True)

    a_suffix = f"{a_sim.topology}.{a_sim.group}"
    b_suffix = f"{b_sim.topology}.{b_sim.group}"


    res = find_frame_pairs(
        a_n_wholes,
        a_n_segments,
        b_n_wholes,
        b_n_segments,
        a_continuous=a_continuous,
        b_continuous=b_continuous,
        max_a_frames=max_a_frames,
        max_b_frames=max_b_frames,
    )

    a_fps, b_fps = res[0]
    print(f"A = {a_fps}, B = {b_fps}")

    # --- Split adump ---
    split_lammps_trj(
        all_path,
        n_wholes=a_n_wholes,
        n_segments=a_n_segments,
        frames_per_segment=a_fps,
        prefix=a_sim.ensemble_long,
        suffix=a_suffix,
        continuous=a_continuous,
        compress_output=False,
        save_to=save_to
    )

    # --- Split bdump ---
    split_lammps_trj(
        bug_path,
        n_wholes=b_n_wholes,
        n_segments=b_n_segments,
        frames_per_segment=b_fps,
        prefix=b_sim.ensemble_long,
        suffix=b_suffix,
        continuous=b_continuous,
        compress_output=False,
        save_to=save_to
    )

    # --- Get first/last paths ---
    all_path = Path(all_path)
    bug_path = Path(bug_path)
    
    
    a_split_first_path = all_path.parent / Path(f"{save_to}/{a_sim.ensemble_long}.ens1.j1.{a_suffix}.lammpstrj")
    a_split_last_path = all_path.parent / Path(f"{save_to}/{a_sim.ensemble_long}.ens{a_n_wholes}.j{a_n_segments}.{a_suffix}.lammpstrj")

    b_split_first_path = bug_path.parent / Path(f"{save_to}/{b_sim.ensemble_long}.ens1.{b_suffix}.lammpstrj")
    b_split_last_path = bug_path.parent / Path(f"{save_to}/{b_sim.ensemble_long}.ens{b_n_wholes}.{b_suffix}.lammpstrj")


    a_first, _ = get_first_last_timesteps(str(a_split_first_path))
    _, a_last = get_first_last_timesteps(str(a_split_last_path)) 
    b_first, _ = get_first_last_timesteps(str(b_split_first_path))
    _, b_last = get_first_last_timesteps(str(b_split_last_path))

    assert a_first == b_first, f"First timesteps differ: a={a_first}, b={b_first}"
    assert a_last == b_last, f"Last timesteps differ: a={a_last}, b={b_last}"

ipytest.run()

[32m.[0m[32m.[0m[32m.[0m[32m.[0m[32m.[0m[32m.[0m[32m.[0m[32m.[0m[32m.[0m[32m.[0m[32m.[0m[32m.[0m[32m                                                                                 [100%][0m
[32m[32m[1m12 passed[0m[32m in 7.05s[0m[0m


<ExitCode.OK: 0>