Skip to content

Commit

Permalink
Fix at comparegsi network calculation comparison for switch to polars…
Browse files Browse the repository at this point in the history
… estimators (#149)
  • Loading branch information
lukeshingles committed Feb 6, 2024
1 parent 87b0098 commit c43186d
Show file tree
Hide file tree
Showing 8 changed files with 204 additions and 180 deletions.
1 change: 1 addition & 0 deletions artistools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from artistools import commands
from artistools import deposition
from artistools import estimators
from artistools import gsinetwork
from artistools import inputmodel
from artistools import lightcurve
from artistools import macroatom
Expand Down
1 change: 0 additions & 1 deletion artistools/estimators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from artistools.estimators.estimators import get_averaged_estimators
from artistools.estimators.estimators import get_averageexcitation
from artistools.estimators.estimators import get_ionrecombrates_fromfile
from artistools.estimators.estimators import get_partiallycompletetimesteps
from artistools.estimators.estimators import get_units_string
from artistools.estimators.estimators import get_variablelongunits
from artistools.estimators.estimators import get_variableunits
Expand Down
31 changes: 13 additions & 18 deletions artistools/estimators/estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,16 +331,26 @@ def scan_estimators(
).lazy()

# print(f" matching cells {match_modelgridindex} and timesteps {match_timestep}")

mpiranklist = at.get_mpiranklist(modelpath, only_ranks_withgridcells=True)
mpirank_groups = list(batched(mpiranklist, 100))
mpiranks_matched = (
{at.get_mpirankofcell(modelpath=modelpath, modelgridindex=mgi) for mgi in match_modelgridindex}
if match_modelgridindex
else set(mpiranklist)
)
mpirank_groups = [
(batchindex, mpiranks)
for batchindex, mpiranks in enumerate(batched(mpiranklist, 100))
if mpiranks_matched.intersection(mpiranks)
]

runfolders = at.get_runfolders(modelpath, timesteps=match_timestep)

parquetfiles = (
get_rankbatch_parquetfile(modelpath, runfolder, mpiranks, batchindex=batchindex)
for runfolder in runfolders
for batchindex, mpiranks in enumerate(mpirank_groups)
for batchindex, mpiranks in mpirank_groups
)
assert bool(parquetfiles)

pldflazy = pl.concat([pl.scan_parquet(pfile) for pfile in parquetfiles], how="diagonal_relaxed")
pldflazy = pldflazy.unique(["timestep", "modelgridindex"], maintain_order=True, keep="first")
Expand Down Expand Up @@ -461,18 +471,3 @@ def get_averageexcitation(
energypopsum += energy_boltzfac_sum * superlevelrow.n_NLTE / boltzfac_sum

return energypopsum / ionpopsum


def get_partiallycompletetimesteps(estimators: dict[tuple[int, int], dict[str, t.Any]]) -> list[int]:
"""During a simulation, some estimator files can contain information for some cells but not others
for the current timestep.
"""
timestepcells: dict[int, list[int]] = {}
all_mgis = set()
for nts, mgi in estimators:
if nts not in timestepcells:
timestepcells[nts] = []
timestepcells[nts].append(mgi)
all_mgis.add(mgi)

return [nts for nts, mgilist in timestepcells.items() if len(mgilist) < len(all_mgis)]
Loading

0 comments on commit c43186d

Please sign in to comment.