Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add DataCollection.plot_missing_data() to visualize missing data #402

Merged
merged 13 commits into from Jun 5, 2023

Conversation

JamesWrigley
Copy link
Member

I've recently started making DAMNIT variables to plot the missing data for various sources in a run, and I figured it'd be useful to have something like that in extra-data considering how frequently it comes up. Originally I did this with line and scatter plots, but those get completely unreadable with lots of sources, so I opted for horizontal broken bar plots instead.

The downside of the bar plots is that if the plot is too small then it's impossible to see infrequent missing/present trains, but I think that's ok (you can see an example in the GIF below when min_saved_pct=100). One way around that would be to have a tiny scatter plot above the bars showing where the missing/present trains are.

Things I'm unsure about:

  • The check for missing trains might not be robust, it's basically doing this: set(source_tids) != set(run_tids). But that'll kinda break if train IDs repeat, does that ever happen?
  • The code doesn't check the trains for each source-key pair individually, only the first key for the source. Is that ok? I think it is because if I remember correctly the DAQ drops entire hashes at a time if any data in the hash is missing.
  • I called the method plot_missing_data() because technically it's the data for a source that's missing, not entire trains from the run. But I think everyone calls this problem 'missing trains' so maybe it should be named plot_missing_trains() instead? I don't have a strong opinion about that.

Examples:
plot-missing-data

image

image

@tmichela
Copy link
Member

tmichela commented May 16, 2023

pretty cool! Would be amazing as a console script, but probably much harder to render it 😁
I wonder if it would also be useful to add markers for start/end of individual files(?)

The check for missing trains might not be robust, it's basically doing this: set(source_tids) != set(run_tids). But that'll kinda break if train IDs repeat, does that ever happen?

If that happens it is a broken file and extra-data-validate will report it.

The code doesn't check the trains for each source-key pair individually, only the first key for the source. Is that ok? I think it is because if I remember correctly the DAQ drops entire hashes at a time if any data in the hash is missing.

This is fine as you have a single index entry per source not per key, except for xtdf data which have indexes per sub-section. In this case you probably want to check the image section.

I called the method plot_missing_data() because technically it's the data for a source that's missing, not entire trains from the run. But I think everyone calls this problem 'missing trains' so maybe it should be named plot_missing_trains() instead? I don't have a strong opinion about that.

I think either is fine (I prefer plot_missing_data).

@takluyver
Copy link
Member

Nice!

When I first glanced at these graphs, I thought red was data and black was gaps. Looking more closely, I realise I had that backwards. I wonder if there's anything we can do to make it more obvious? E.g. different colours? I guess you were going for red=bad, whereas I guessed black=empty. Maybe something like blue/green for data and black/white for missing? 🤔 Obviously we could have a legend, but it takes up space, and if you think you understand, you might not look at the legend.

How fast is it with virtual overview files? Maybe we should just push for making those in more situations? I'd like to avoid detecting 'am I in a notebook' and drawing progress bars if possible.

@JamesWrigley
Copy link
Member Author

Here are some profiles of plot_missing_data(), first with only raw data selected (i.e. reading from a single voview file):

p3422, run 158. Total running time: 4.7s, bottleneck is SourceData.keys()
Timer unit: 1e-06 s

Total time: 0.213708 s
File: /home/wrigleyj/src/EXtra-data/extra_data/file_access.py
Function: _read_index at line 330

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   330                                               def _read_index(self, source, group):
   331                                                   """Get first index & count for a source.
   332                                           
   333                                                   This is 'real' reading when the requested index is not in the cache.
   334                                                   """
   335        79        108.0      1.4      0.1          ntrains = len(self.train_ids)
   336        79     157741.0   1996.7     73.8          ix_group = self.file['/INDEX/{}/{}'.format(source, group)]
   337        79      27660.0    350.1     12.9          firsts = ix_group['first'][:ntrains]
   338        79       3178.0     40.2      1.5          if 'count' in ix_group:
   339        79      24901.0    315.2     11.7              counts = ix_group['count'][:ntrains]
   340                                                   else:
   341                                                       status = ix_group['status'][:ntrains]
   342                                                       counts = np.uint64((ix_group['last'][:ntrains] - firsts + 1) * status)
   343        79        120.0      1.5      0.1          return firsts, counts

Total time: 3.30463 s
File: /home/wrigleyj/src/EXtra-data/extra_data/file_access.py
Function: get_keys at line 373

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   373                                               def get_keys(self, source):
   374                                                   """Get keys for a given source name
   375                                           
   376                                                   Keys are found by walking the HDF5 file, and cached for reuse.
   377                                                   """
   378        79         62.0      0.8      0.0          try:
   379        79        212.0      2.7      0.0              return self._keys_cache[source]
   380        79         76.0      1.0      0.0          except KeyError:
   381        79         73.0      0.9      0.0              pass
   382                                           
   383        79         78.0      1.0      0.0          if source in self.control_sources:
   384        54         58.0      1.1      0.0              group = '/CONTROL/' + source
   385        25         19.0      0.8      0.0          elif source in self.instrument_sources:
   386        25         26.0      1.0      0.0              group = '/INSTRUMENT/' + source
   387                                                   else:
   388                                                       raise SourceNameError(source)
   389                                           
   390        79         85.0      1.1      0.0          res = set()
   391                                           
   392        79         76.0      1.0      0.0          def add_key(key, value):
   393                                                       if isinstance(value, h5py.Dataset):
   394                                                           res.add(key.replace('/', '.'))
   395                                           
   396        79    3303601.0  41817.7    100.0          self.file[group].visititems(add_key)
   397        79        223.0      2.8      0.0          self._keys_cache[source] = res
   398        79         46.0      0.6      0.0          return res

Total time: 0.313287 s
File: /home/wrigleyj/src/EXtra-data/extra_data/keydata.py
Function: _find_chunks at line 30

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    30                                               def _find_chunks(self):
    31                                                   """Find contiguous chunks of data for this key, in any order."""
    32        79      49302.0    624.1     15.7          all_tids_arr = np.array(self.train_ids)
    33                                           
    34       158        178.0      1.1      0.1          for file in self.files:
    35        79        148.0      1.9      0.0              if len(file.train_ids) == 0:
    36                                                           continue
    37                                           
    38        79     215459.0   2727.3     68.8              firsts, counts = file.get_index(self.source, self._key_group)
    39                                           
    40                                                       # Of trains in this file, which are in selection
    41        79      32400.0    410.1     10.3              include = np.isin(file.train_ids, all_tids_arr)
    42        79        144.0      1.8      0.0              if not self.inc_suspect_trains:
    43                                                           include &= file.validity_flag
    44                                           
    45                                                       # Assemble contiguous chunks of data from this file
    46       158      14516.0     91.9      4.6              for _from, _to in contiguous_regions(include):
    47       158        484.0      3.1      0.2                  yield DataChunk(
    48        79        306.0      3.9      0.1                      file, self.hdf5_data_path,
    49        79        124.0      1.6      0.0                      first=firsts[_from],
    50        79        149.0      1.9      0.0                      train_ids=file.train_ids[_from:_to],
    51        79         77.0      1.0      0.0                      counts=counts[_from:_to],
    52                                                           )

Total time: 0.615803 s
File: /home/wrigleyj/src/EXtra-data/extra_data/keydata.py
Function: drop_empty_trains at line 146

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   146                                               def drop_empty_trains(self):
   147                                                   """Select only trains with data as a new :class:`KeyData` object."""
   148        79     343418.0   4347.1     55.8          counts = self.data_counts()
   149                                                   # Note: we do this strange UInt64Index -> ndarray -> list conversion
   150                                                   # because for some reason panda's .to_list() method generates
   151                                                   # sub-optimal lists. If .to_list() is used, then operations on the
   152                                                   # train_id's list of the new KeyData object are significantly slower
   153                                                   # (even the call to _only_tids() below).
   154        79      75849.0    960.1     12.3          tids = list(counts[counts > 0].index.to_numpy())
   155                                           
   156        79     196536.0   2487.8     31.9          return self._only_tids(list(tids))

Total time: 0.342559 s
File: /home/wrigleyj/src/EXtra-data/extra_data/keydata.py
Function: data_counts at line 183

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   183                                               def data_counts(self, labelled=True):
   184                                                   """Get a count of data entries in each train.
   185                                           
   186                                                   If *labelled* is True, returns a pandas series with an index of train
   187                                                   IDs. Otherwise, returns a NumPy array of counts to match ``.train_ids``.
   188                                                   """
   189        79     315267.0   3990.7     92.0          if self._data_chunks:
   190        79       1636.0     20.7      0.5              train_ids = np.concatenate([c.train_ids for c in self._data_chunks])
   191        79       1286.0     16.3      0.4              counts = np.concatenate([c.counts for c in self._data_chunks])
   192                                                   else:
   193                                                       train_ids = counts = np.zeros(0, dtype=np.uint64)
   194                                           
   195        79         85.0      1.1      0.0          if labelled:
   196        79        147.0      1.9      0.0              import pandas as pd
   197        79      24138.0    305.5      7.0              return pd.Series(counts, index=train_ids)
   198                                                   else:
   199                                                       all_tids_arr = np.array(self.train_ids)
   200                                                       res = np.zeros(len(all_tids_arr), dtype=np.uint64)
   201                                                       tid_to_ix = np.intersect1d(all_tids_arr, train_ids, return_indices=True)[1]
   202                                           
   203                                                       # We may be missing some train IDs, if they're not in any file
   204                                                       # for this source, and they're sometimes out of order within chunks
   205                                                       # (they shouldn't be, but we try not to fail too badly if they are).
   206                                                       # assert np.isin(train_ids, all_tids_arr).all()
   207                                                       assert len(tid_to_ix) == len(train_ids)
   208                                                       res[tid_to_ix] = counts
   209                                           
   210                                                       return res

Total time: 4.7228 s
File: /home/wrigleyj/src/EXtra-data/extra_data/reader.py
Function: plot_missing_data at line 1333

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
  1333                                               def plot_missing_data(self, min_saved_pct=95):
  1334                                                   """Plot sources that have missing data for some trains.
  1335                                           
  1336                                                   Parameters
  1337                                                   ----------
  1338                                           
  1339                                                   min_saved_pct: int or float, optional
  1340                                                       Only show sources with less than this percentage of trains saved.
  1341                                                   """
  1342         1          4.0      4.0      0.0          n_trains = len(self.train_ids)
  1343                                           
  1344                                                   # Helper function that returns an alias for a source if one is
  1345                                                   # available, and the source name otherwise.
  1346         1          2.0      2.0      0.0          def best_src_name(src):
  1347                                                       for alias, alias_ident in self._aliases.items():
  1348                                                           if isinstance(alias_ident, str) and alias_ident == src:
  1349                                                               return alias
  1350                                           
  1351                                                       return src
  1352                                           
  1353                                                   # If possible, create a progress bar. Loading the train IDs for every
  1354                                                   # source can be very slow (tens of seconds) so it's good to give the
  1355                                                   # user some feedback to know that the method hasn't frozen.
  1356         1          1.0      1.0      0.0          display_progress = False
  1357         1          7.0      7.0      0.0          if self._running_in_notebook():
  1358         1          2.0      2.0      0.0              try:
  1359         1          9.0      9.0      0.0                  from ipywidgets import IntProgress
  1360         1          5.0      5.0      0.0                  from IPython.display import display
  1361                                                       except ImportError:
  1362                                                           pass
  1363                                                       else:
  1364         2      12908.0   6454.0      0.3                  progress_bar = IntProgress(min=0, max=len(self.all_sources),
  1365         1          1.0      1.0      0.0                                             description="Checking:")
  1366         1       4042.0   4042.0      0.1                  display(progress_bar)
  1367         1          2.0      2.0      0.0                  display_progress = True
  1368                                           
  1369                                                   # Find sources with missing data
  1370         1          2.0      2.0      0.0          flaky_sources = { }
  1371         1        455.0    455.0      0.0          run_tids = set(self.train_ids)
  1372        80        256.0      3.2      0.0          for src in self.all_sources:
  1373        79    3306874.0  41859.2     70.0              key = list(self[src].keys())[0]
  1374        79      34266.0    433.7      0.7              kd = self[src, key]
  1375        79     626145.0   7925.9     13.3              kd_tids = kd.drop_empty_trains().train_ids
  1376        79        349.0      4.4      0.0              save_pct = len(kd_tids) / n_trains * 100
  1377                                           
  1378        79      57565.0    728.7      1.2              if set(kd_tids) != run_tids and save_pct <= min_saved_pct:
  1379         3        139.0     46.3      0.0                  flaky_sources[best_src_name(src)] = kd_tids
  1380                                           
  1381        79        174.0      2.2      0.0              if display_progress:
  1382        79      69584.0    880.8      1.5                  progress_bar.value += 1
  1383                                           
  1384                                                   # Hide the progress bar now that we've checked all the sources
  1385         1          2.0      2.0      0.0          if display_progress:
  1386         1        571.0    571.0      0.0              progress_bar.close()
  1387                                           
  1388                                                   # Sort the flaky sources by decreasing order of how many trains they're missing
  1389         2         12.0      6.0      0.0          flaky_sources = dict(sorted(flaky_sources.items(), key=lambda x: len(x[1]),
  1390         1          2.0      2.0      0.0                                      reverse=True)
  1391                                                                        )
  1392                                           
  1393                                                   # Plot missing data
  1394         1          7.0      7.0      0.0          import matplotlib.pyplot as plt
  1395         1      48607.0  48607.0      1.0          fig, ax = plt.subplots(figsize=(9, max(2, len(flaky_sources) / 4)))
  1396                                           
  1397         1          2.0      2.0      0.0          bar_height = 0.5
  1398         4         15.0      3.8      0.0          for i, src in enumerate(flaky_sources):
  1399                                                       # First find all the trains that are missing
  1400         3        145.0     48.3      0.0              save_line = np.zeros(n_trains).astype(bool)
  1401         3       3883.0   1294.3      0.1              save_line[np.intersect1d(self.train_ids, flaky_sources[src], return_indices=True)[1]] = True
  1402                                           
  1403                                                       # Loop over each train to find blocks of trains that are either
  1404                                                       # present or missing.
  1405         3        122.0     40.7      0.0              bars = { }
  1406         3          6.0      2.0      0.0              block_start = 0
  1407     36582      66219.0      1.8      1.4              for idx in range(n_trains):
  1408     36579     116548.0      3.2      2.5                  if save_line[idx] != save_line[block_start]:
  1409                                                               # If we find a train that doesn't match the save status of
  1410                                                               # the current block, create a new entry in `bars` to record
  1411                                                               # the start index, the block length, and the save status.
  1412      3276       6752.0      2.1      0.1                      bars[(block_start, idx - block_start)] = save_line[block_start]
  1413      3276       6264.0      1.9      0.1                      block_start = idx
  1414                                           
  1415                                                       # Add the last block
  1416         3          8.0      2.7      0.0              bars[(block_start, n_trains - block_start)] = save_line[block_start]
  1417                                           
  1418                                                       # Plot all the blocks
  1419         6     282257.0  47042.8      6.0              ax.broken_barh(bars.keys(),
  1420         3          6.0      2.0      0.0                             (i, bar_height),
  1421         3        501.0    167.0      0.0                             color=["k" if x else "r" for x in bars.values()])
  1422                                           
  1423                                                   # Set labels and ticks
  1424         2         17.0      8.5      0.0          tick_labels = [f"{src} ({len(tids) / n_trains * 100:.2f}%)"
  1425         1          4.0      4.0      0.0                    for i, (src, tids) in enumerate(flaky_sources.items())]
  1426         2       5588.0   2794.0      0.1          ax.set_yticks(np.arange(len(flaky_sources)) + bar_height / 2,
  1427         1          2.0      2.0      0.0                        labels=tick_labels, fontsize=8)
  1428         1        109.0    109.0      0.0          ax.set_xlabel("Train ID index")
  1429                                           
  1430                                                   # Set title
  1431         1          4.0      4.0      0.0          title = f"Sources with less than {min_saved_pct}% of trains saved"
  1432         1         11.0     11.0      0.0          run_meta = self.run_metadata()
  1433         1          2.0      2.0      0.0          if "proposalNumber" in run_meta and "runNumber" in run_meta:
  1434                                                       title += f" in p{run_meta['proposalNumber']}, run {run_meta['runNumber']}"
  1435         1        352.0    352.0      0.0          ax.set_title(title)
  1436                                           
  1437         1      71988.0  71988.0      1.5          fig.tight_layout()
  1438                                           
  1439         1          3.0      3.0      0.0          return ax

And then with raw + proc files (769 in total):

p3422, run 159. Total running time: 15.8s, bottleneck is KeyData.drop_empty_trains()
Timer unit: 1e-06 s

Total time: 8.14472 s
File: /home/wrigleyj/src/EXtra-data/extra_data/file_access.py
Function: _read_index at line 330

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   330                                               def _read_index(self, source, group):
   331                                                   """Get first index & count for a source.
   332                                           
   333                                                   This is 'real' reading when the requested index is not in the cache.
   334                                                   """
   335       831        918.0      1.1      0.0          ntrains = len(self.train_ids)
   336       831    7470411.0   8989.7     91.7          ix_group = self.file['/INDEX/{}/{}'.format(source, group)]
   337       831     334232.0    402.2      4.1          firsts = ix_group['first'][:ntrains]
   338       831      32903.0     39.6      0.4          if 'count' in ix_group:
   339       831     305019.0    367.1      3.7              counts = ix_group['count'][:ntrains]
   340                                                   else:
   341                                                       status = ix_group['status'][:ntrains]
   342                                                       counts = np.uint64((ix_group['last'][:ntrains] - firsts + 1) * status)
   343       831       1239.0      1.5      0.0          return firsts, counts

Total time: 5.76136 s
File: /home/wrigleyj/src/EXtra-data/extra_data/file_access.py
Function: get_keys at line 373

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   373                                               def get_keys(self, source):
   374                                                   """Get keys for a given source name
   375                                           
   376                                                   Keys are found by walking the HDF5 file, and cached for reuse.
   377                                                   """
   378        79         74.0      0.9      0.0          try:
   379        79        195.0      2.5      0.0              return self._keys_cache[source]
   380        79         71.0      0.9      0.0          except KeyError:
   381        79         82.0      1.0      0.0              pass
   382                                           
   383        79         85.0      1.1      0.0          if source in self.control_sources:
   384        54         57.0      1.1      0.0              group = '/CONTROL/' + source
   385        25         30.0      1.2      0.0          elif source in self.instrument_sources:
   386        25         28.0      1.1      0.0              group = '/INSTRUMENT/' + source
   387                                                   else:
   388                                                       raise SourceNameError(source)
   389                                           
   390        79         82.0      1.0      0.0          res = set()
   391                                           
   392        79         76.0      1.0      0.0          def add_key(key, value):
   393                                                       if isinstance(value, h5py.Dataset):
   394                                                           res.add(key.replace('/', '.'))
   395                                           
   396        79    5760343.0  72915.7    100.0          self.file[group].visititems(add_key)
   397        79        201.0      2.5      0.0          self._keys_cache[source] = res
   398        79         40.0      0.5      0.0          return res

Total time: 8.68333 s
File: /home/wrigleyj/src/EXtra-data/extra_data/keydata.py
Function: _find_chunks at line 30

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    30                                               def _find_chunks(self):
    31                                                   """Find contiguous chunks of data for this key, in any order."""
    32        79      50529.0    639.6      0.6          all_tids_arr = np.array(self.train_ids)
    33                                           
    34       910        842.0      0.9      0.0          for file in self.files:
    35       831       1504.0      1.8      0.0              if len(file.train_ids) == 0:
    36                                                           continue
    37                                           
    38       831    8162964.0   9823.1     94.0              firsts, counts = file.get_index(self.source, self._key_group)
    39                                           
    40                                                       # Of trains in this file, which are in selection
    41       831     307740.0    370.3      3.5              include = np.isin(file.train_ids, all_tids_arr)
    42       831       1421.0      1.7      0.0              if not self.inc_suspect_trains:
    43                                                           include &= file.validity_flag
    44                                           
    45                                                       # Assemble contiguous chunks of data from this file
    46      1662     146910.0     88.4      1.7              for _from, _to in contiguous_regions(include):
    47      1662       4879.0      2.9      0.1                  yield DataChunk(
    48       831       3161.0      3.8      0.0                      file, self.hdf5_data_path,
    49       831       1162.0      1.4      0.0                      first=firsts[_from],
    50       831       1506.0      1.8      0.0                      train_ids=file.train_ids[_from:_to],
    51       831        708.0      0.9      0.0                      counts=counts[_from:_to],
    52                                                           )

Total time: 9.23156 s
File: /home/wrigleyj/src/EXtra-data/extra_data/keydata.py
Function: drop_empty_trains at line 146

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   146                                               def drop_empty_trains(self):
   147                                                   """Select only trains with data as a new :class:`KeyData` object."""
   148        79    8725458.0 110448.8     94.5          counts = self.data_counts()
   149                                                   # Note: we do this strange UInt64Index -> ndarray -> list conversion
   150                                                   # because for some reason panda's .to_list() method generates
   151                                                   # sub-optimal lists. If .to_list() is used, then operations on the
   152                                                   # train_id's list of the new KeyData object are significantly slower
   153                                                   # (even the call to _only_tids() below).
   154        79      77901.0    986.1      0.8          tids = list(counts[counts > 0].index.to_numpy())
   155                                           
   156        79     428205.0   5420.3      4.6          return self._only_tids(list(tids))

Total time: 8.7246 s
File: /home/wrigleyj/src/EXtra-data/extra_data/keydata.py
Function: data_counts at line 183

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   183                                               def data_counts(self, labelled=True):
   184                                                   """Get a count of data entries in each train.
   185                                           
   186                                                   If *labelled* is True, returns a pandas series with an index of train
   187                                                   IDs. Otherwise, returns a NumPy array of counts to match ``.train_ids``.
   188                                                   """
   189        79    8694366.0 110055.3     99.7          if self._data_chunks:
   190        79       2164.0     27.4      0.0              train_ids = np.concatenate([c.train_ids for c in self._data_chunks])
   191        79       1930.0     24.4      0.0              counts = np.concatenate([c.counts for c in self._data_chunks])
   192                                                   else:
   193                                                       train_ids = counts = np.zeros(0, dtype=np.uint64)
   194                                           
   195        79         86.0      1.1      0.0          if labelled:
   196        79        160.0      2.0      0.0              import pandas as pd
   197        79      25897.0    327.8      0.3              return pd.Series(counts, index=train_ids)
   198                                                   else:
   199                                                       all_tids_arr = np.array(self.train_ids)
   200                                                       res = np.zeros(len(all_tids_arr), dtype=np.uint64)
   201                                                       tid_to_ix = np.intersect1d(all_tids_arr, train_ids, return_indices=True)[1]
   202                                           
   203                                                       # We may be missing some train IDs, if they're not in any file
   204                                                       # for this source, and they're sometimes out of order within chunks
   205                                                       # (they shouldn't be, but we try not to fail too badly if they are).
   206                                                       # assert np.isin(train_ids, all_tids_arr).all()
   207                                                       assert len(tid_to_ix) == len(train_ids)
   208                                                       res[tid_to_ix] = counts
   209                                           
   210                                                       return res

Total time: 15.8437 s
File: /home/wrigleyj/src/EXtra-data/extra_data/reader.py
Function: plot_missing_data at line 1333

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
  1333                                               def plot_missing_data(self, min_saved_pct=95):
  1334                                                   """Plot sources that have missing data for some trains.
  1335                                           
  1336                                                   Parameters
  1337                                                   ----------
  1338                                           
  1339                                                   min_saved_pct: int or float, optional
  1340                                                       Only show sources with less than this percentage of trains saved.
  1341                                                   """
  1342         1          4.0      4.0      0.0          n_trains = len(self.train_ids)
  1343                                           
  1344                                                   # Helper function that returns an alias for a source if one is
  1345                                                   # available, and the source name otherwise.
  1346         1          3.0      3.0      0.0          def best_src_name(src):
  1347                                                       for alias, alias_ident in self._aliases.items():
  1348                                                           if isinstance(alias_ident, str) and alias_ident == src:
  1349                                                               return alias
  1350                                           
  1351                                                       return src
  1352                                           
  1353                                                   # If possible, create a progress bar. Loading the train IDs for every
  1354                                                   # source can be very slow (tens of seconds) so it's good to give the
  1355                                                   # user some feedback to know that the method hasn't frozen.
  1356         1          2.0      2.0      0.0          display_progress = False
  1357         1          8.0      8.0      0.0          if self._running_in_notebook():
  1358         1          2.0      2.0      0.0              try:
  1359         1         11.0     11.0      0.0                  from ipywidgets import IntProgress
  1360         1          5.0      5.0      0.0                  from IPython.display import display
  1361                                                       except ImportError:
  1362                                                           pass
  1363                                                       else:
  1364         2      12863.0   6431.5      0.1                  progress_bar = IntProgress(min=0, max=len(self.all_sources),
  1365         1          2.0      2.0      0.0                                             description="Checking:")
  1366         1       4188.0   4188.0      0.0                  display(progress_bar)
  1367         1          3.0      3.0      0.0                  display_progress = True
  1368                                           
  1369                                                   # Find sources with missing data
  1370         1          2.0      2.0      0.0          flaky_sources = { }
  1371         1        490.0    490.0      0.0          run_tids = set(self.train_ids)
  1372        80        269.0      3.4      0.0          for src in self.all_sources:
  1373        79    5763773.0  72959.2     36.4              key = list(self[src].keys())[0]
  1374        79      29373.0    371.8      0.2              kd = self[src, key]
  1375        79    9242555.0 116994.4     58.3              kd_tids = kd.drop_empty_trains().train_ids
  1376        79        355.0      4.5      0.0              save_pct = len(kd_tids) / n_trains * 100
  1377                                           
  1378        79      60006.0    759.6      0.4              if set(kd_tids) != run_tids and save_pct <= min_saved_pct:
  1379         3        120.0     40.0      0.0                  flaky_sources[best_src_name(src)] = kd_tids
  1380                                           
  1381        79        173.0      2.2      0.0              if display_progress:
  1382        79      73030.0    924.4      0.5                  progress_bar.value += 1
  1383                                           
  1384                                                   # Hide the progress bar now that we've checked all the sources
  1385         1          2.0      2.0      0.0          if display_progress:
  1386         1        601.0    601.0      0.0              progress_bar.close()
  1387                                           
  1388                                                   # Sort the flaky sources by decreasing order of how many trains they're missing
  1389         2         12.0      6.0      0.0          flaky_sources = dict(sorted(flaky_sources.items(), key=lambda x: len(x[1]),
  1390         1          2.0      2.0      0.0                                      reverse=True)
  1391                                                                        )
  1392                                           
  1393                                                   # Plot missing data
  1394         1          7.0      7.0      0.0          import matplotlib.pyplot as plt
  1395         1      45435.0  45435.0      0.3          fig, ax = plt.subplots(figsize=(9, max(2, len(flaky_sources) / 4)))
  1396                                           
  1397         1          3.0      3.0      0.0          bar_height = 0.5
  1398         4         15.0      3.8      0.0          for i, src in enumerate(flaky_sources):
  1399                                                       # First find all the trains that are missing
  1400         3        200.0     66.7      0.0              save_line = np.zeros(n_trains).astype(bool)
  1401         3       4140.0   1380.0      0.0              save_line[np.intersect1d(self.train_ids, flaky_sources[src], return_indices=True)[1]] = True
  1402                                           
  1403                                                       # Loop over each train to find blocks of trains that are either
  1404                                                       # present or missing.
  1405         3         81.0     27.0      0.0              bars = { }
  1406         3          6.0      2.0      0.0              block_start = 0
  1407     36558      66451.0      1.8      0.4              for idx in range(n_trains):
  1408     36555     116564.0      3.2      0.7                  if save_line[idx] != save_line[block_start]:
  1409                                                               # If we find a train that doesn't match the save status of
  1410                                                               # the current block, create a new entry in `bars` to record
  1411                                                               # the start index, the block length, and the save status.
  1412      3252       6715.0      2.1      0.0                      bars[(block_start, idx - block_start)] = save_line[block_start]
  1413      3252       7148.0      2.2      0.0                      block_start = idx
  1414                                           
  1415                                                       # Add the last block
  1416         3          7.0      2.3      0.0              bars[(block_start, n_trains - block_start)] = save_line[block_start]
  1417                                           
  1418                                                       # Plot all the blocks
  1419         6     278591.0  46431.8      1.8              ax.broken_barh(bars.keys(),
  1420         3          6.0      2.0      0.0                             (i, bar_height),
  1421         3        521.0    173.7      0.0                             color=["k" if x else "r" for x in bars.values()])
  1422                                           
  1423                                                   # Set labels and ticks
  1424         2         17.0      8.5      0.0          tick_labels = [f"{src} ({len(tids) / n_trains * 100:.2f}%)"
  1425         1          3.0      3.0      0.0                    for i, (src, tids) in enumerate(flaky_sources.items())]
  1426         2       5488.0   2744.0      0.0          ax.set_yticks(np.arange(len(flaky_sources)) + bar_height / 2,
  1427         1          2.0      2.0      0.0                        labels=tick_labels, fontsize=8)
  1428         1        108.0    108.0      0.0          ax.set_xlabel("Train ID index")
  1429                                           
  1430                                                   # Set title
  1431         1          4.0      4.0      0.0          title = f"Sources with less than {min_saved_pct}% of trains saved"
  1432         1       4511.0   4511.0      0.0          run_meta = self.run_metadata()
  1433         1          3.0      3.0      0.0          if "proposalNumber" in run_meta and "runNumber" in run_meta:
  1434         1          8.0      8.0      0.0              title += f" in p{run_meta['proposalNumber']}, run {run_meta['runNumber']}"
  1435         1        376.0    376.0      0.0          ax.set_title(title)
  1436                                           
  1437         1     119419.0 119419.0      0.8          fig.tight_layout()
  1438                                           
  1439         1          2.0      2.0      0.0          return ax

Both of the runs are roughly the same length and have the same sources. So voview files already help a ton, and when there are lots of files the time seems to be dominated by opening groups (see the profile for KeyData._find_chunks(), which is called by KeyData._data_chunks(), which is called by KeyData.data_counts(), which is called by KeyData.drop_empty_trains(). BTW this is all after I made some optimizations to KeyData.drop_empty_trains(labelled=False), it's ~3.8x faster now (but of course that doesn't help too much on the first execution where IO dominates).

The question is whether we're ok with making the user wait up to ~20s 🤔 I lean towards no and keeping the progress bar, but I don't feel very strongly about it.

@tmichela
Copy link
Member

What about using tqdm? So you don't have to play with widgets and trying to figure out yourself if you're in a notebook (+ the progress bar also works if you're not in a notebook).

@takluyver
Copy link
Member

Nice! Thanks for measuring that. Nearly 5 seconds with a single 42 MB file is dismal. 😞

We only need 1 key for this, right? I wonder if we can speed it up by adding a method on FileAccess to get one arbitrary key from a given source, rather than finding all keys? I remember we made key in source tests faster by avoiding getting a list of all keys (see the has_source_key() method). If so, I think some of @philsmt's recent PRs might benefit from this as well.

If that works for the cases with a virtual overview file, I lean towards checking something like if len(self.files) > 10, and if so, printing a message saying 'This may take a minute...'. Hopefully if we can make virtual overview files more widespread (including making them for proc data in the calibration pipeline), we can also make this mostly unnecessary. I also don't feel strongly about this, though; if we do go for a progress bar, I think @tmichela's idea of using tqdm is a good one.

@takluyver
Copy link
Member

I tested that theory out, got it down to about 30-50 ms before using any caches:

image

Initial `get_one_key` function

def get_one_key(source):
    self = run.files[0]
    if source in self.control_sources:
        group = '/CONTROL/' + source
    elif source in self.instrument_sources:
        group = '/INSTRUMENT/' + source
    else:
        raise SourceNameError(source)

    res = set()

    def add_key(key, value):
        if isinstance(value, h5py.Dataset):
            return key.replace('/', '.')

    return self.file[group].visititems(add_key)

@JamesWrigley
Copy link
Member Author

Oooo lovely 😀 I'll try using that, and if it helps a bit then we can probably get rid of the progress bars. BTW @tmichela I looked into adding markers for the sequence files, but it'd be kinda painful to follow the trains through a voview file so I decided not to.

@JamesWrigley
Copy link
Member Author

Big improvement with a single voview file \o/

p3422, run 158. Total running time: 1.36s, bottleneck is KeyData.drop_empty_trains()
Timer unit: 1e-06 s

Total time: 0.141316 s
File: /home/wrigleyj/src/EXtra-data/extra_data/file_access.py
Function: _read_index at line 330

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   330                                               def _read_index(self, source, group):
   331                                                   """Get first index & count for a source.
   332                                           
   333                                                   This is 'real' reading when the requested index is not in the cache.
   334                                                   """
   335        79         71.0      0.9      0.1          ntrains = len(self.train_ids)
   336        79      95768.0   1212.3     67.8          ix_group = self.file['/INDEX/{}/{}'.format(source, group)]
   337        79      22610.0    286.2     16.0          firsts = ix_group['first'][:ntrains]
   338        79       2490.0     31.5      1.8          if 'count' in ix_group:
   339        79      20280.0    256.7     14.4              counts = ix_group['count'][:ntrains]
   340                                                   else:
   341                                                       status = ix_group['status'][:ntrains]
   342                                                       counts = np.uint64((ix_group['last'][:ntrains] - firsts + 1) * status)
   343        79         97.0      1.2      0.1          return firsts, counts

Total time: 0 s
File: /home/wrigleyj/src/EXtra-data/extra_data/file_access.py
Function: get_keys at line 373

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   373                                               def get_keys(self, source):
   374                                                   """Get keys for a given source name
   375                                           
   376                                                   Keys are found by walking the HDF5 file, and cached for reuse.
   377                                                   """
   378                                                   try:
   379                                                       return self._keys_cache[source]
   380                                                   except KeyError:
   381                                                       pass
   382                                           
   383                                                   if source in self.control_sources:
   384                                                       group = '/CONTROL/' + source
   385                                                   elif source in self.instrument_sources:
   386                                                       group = '/INSTRUMENT/' + source
   387                                                   else:
   388                                                       raise SourceNameError(source)
   389                                           
   390                                                   res = set()
   391                                           
   392                                                   def add_key(key, value):
   393                                                       if isinstance(value, h5py.Dataset):
   394                                                           res.add(key.replace('/', '.'))
   395                                           
   396                                                   self.file[group].visititems(add_key)
   397                                                   self._keys_cache[source] = res
   398                                                   return res

Total time: 0.229125 s
File: /home/wrigleyj/src/EXtra-data/extra_data/keydata.py
Function: _find_chunks at line 30

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    30                                               def _find_chunks(self):
    31                                                   """Find contiguous chunks of data for this key, in any order."""
    32        79      40284.0    509.9     17.6          all_tids_arr = np.array(self.train_ids)
    33                                           
    34       158        157.0      1.0      0.1          for file in self.files:
    35        79        116.0      1.5      0.1              if len(file.train_ids) == 0:
    36                                                           continue
    37                                           
    38        79     142863.0   1808.4     62.4              firsts, counts = file.get_index(self.source, self._key_group)
    39                                           
    40                                                       # Of trains in this file, which are in selection
    41        79      29727.0    376.3     13.0              include = np.isin(file.train_ids, all_tids_arr)
    42        79        132.0      1.7      0.1              if not self.inc_suspect_trains:
    43                                                           include &= file.validity_flag
    44                                           
    45                                                       # Assemble contiguous chunks of data from this file
    46       158      14722.0     93.2      6.4              for _from, _to in contiguous_regions(include):
    47       158        463.0      2.9      0.2                  yield DataChunk(
    48        79        333.0      4.2      0.1                      file, self.hdf5_data_path,
    49        79        113.0      1.4      0.0                      first=firsts[_from],
    50        79        127.0      1.6      0.1                      train_ids=file.train_ids[_from:_to],
    51        79         88.0      1.1      0.0                      counts=counts[_from:_to],
    52                                                           )

Total time: 0.546617 s
File: /home/wrigleyj/src/EXtra-data/extra_data/keydata.py
Function: drop_empty_trains at line 146

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   146                                               def drop_empty_trains(self):
   147                                                   """Select only trains with data as a new :class:`KeyData` object."""
   148        79     259825.0   3288.9     47.5          counts = self.data_counts()
   149                                                   # Note: we do this strange UInt64Index -> ndarray -> list conversion
   150                                                   # because for some reason panda's .to_list() method generates
   151                                                   # sub-optimal lists. If .to_list() is used, then operations on the
   152                                                   # train_id's list of the new KeyData object are significantly slower
   153                                                   # (even the call to _only_tids() below).
   154        79      78017.0    987.6     14.3          tids = list(counts[counts > 0].index.to_numpy())
   155                                           
   156        79     208775.0   2642.7     38.2          return self._only_tids(list(tids))

Total time: 0.258895 s
File: /home/wrigleyj/src/EXtra-data/extra_data/keydata.py
Function: data_counts at line 183

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   183                                               def data_counts(self, labelled=True):
   184                                                   """Get a count of data entries in each train.
   185                                           
   186                                                   If *labelled* is True, returns a pandas series with an index of train
   187                                                   IDs. Otherwise, returns a NumPy array of counts to match ``.train_ids``.
   188                                                   """
   189        79     231133.0   2925.7     89.3          if self._data_chunks:
   190        79       1712.0     21.7      0.7              train_ids = np.concatenate([c.train_ids for c in self._data_chunks])
   191        79        965.0     12.2      0.4              counts = np.concatenate([c.counts for c in self._data_chunks])
   192                                                   else:
   193                                                       train_ids = counts = np.zeros(0, dtype=np.uint64)
   194                                           
   195        79         73.0      0.9      0.0          if labelled:
   196        79        167.0      2.1      0.1              import pandas as pd
   197        79      24845.0    314.5      9.6              return pd.Series(counts, index=train_ids)
   198                                                   else:
   199                                                       all_tids_arr = np.array(self.train_ids)
   200                                                       res = np.zeros(len(all_tids_arr), dtype=np.uint64)
   201                                                       tid_to_ix = np.intersect1d(all_tids_arr, train_ids, return_indices=True)[1]
   202                                           
   203                                                       # We may be missing some train IDs, if they're not in any file
   204                                                       # for this source, and they're sometimes out of order within chunks
   205                                                       # (they shouldn't be, but we try not to fail too badly if they are).
   206                                                       # assert np.isin(train_ids, all_tids_arr).all()
   207                                                       assert len(tid_to_ix) == len(train_ids)
   208                                                       res[tid_to_ix] = counts
   209                                           
   210                                                       return res

Total time: 1.35892 s
File: /home/wrigleyj/src/EXtra-data/extra_data/reader.py
Function: plot_missing_data at line 1333

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
  1333                                               def plot_missing_data(self, min_saved_pct=95):
  1334                                                   """Plot sources that have missing data for some trains.
  1335                                           
  1336                                                   Parameters
  1337                                                   ----------
  1338                                           
  1339                                                   min_saved_pct: int or float, optional
  1340                                                       Only show sources with less than this percentage of trains saved.
  1341                                                   """
  1342         1          3.0      3.0      0.0          n_trains = len(self.train_ids)
  1343                                           
  1344                                                   # Helper function that returns an alias for a source if one is
  1345                                                   # available, and the source name otherwise.
  1346         1          2.0      2.0      0.0          def best_src_name(src):
  1347                                                       for alias, alias_ident in self._aliases.items():
  1348                                                           if isinstance(alias_ident, str) and alias_ident == src:
  1349                                                               return alias
  1350                                           
  1351                                                       return src
  1352                                           
  1353                                                   # If possible, create a progress bar. Loading the train IDs for every
  1354                                                   # source can be very slow (tens of seconds) so it's good to give the
  1355                                                   # user some feedback to know that the method hasn't frozen.
  1356         1          2.0      2.0      0.0          display_progress = False
  1357         1          7.0      7.0      0.0          if self._running_in_notebook():
  1358         1          2.0      2.0      0.0              try:
  1359         1          8.0      8.0      0.0                  from ipywidgets import IntProgress
  1360         1          5.0      5.0      0.0                  from IPython.display import display
  1361                                                       except ImportError:
  1362                                                           pass
  1363                                                       else:
  1364         2      11445.0   5722.5      0.8                  progress_bar = IntProgress(min=0, max=len(self.all_sources),
  1365         1          2.0      2.0      0.0                                             description="Checking:")
  1366         1       3614.0   3614.0      0.3                  display(progress_bar)
  1367         1          3.0      3.0      0.0                  display_progress = True
  1368                                           
  1369                                                   # Find sources with missing data
  1370         1          2.0      2.0      0.0          flaky_sources = { }
  1371         1        524.0    524.0      0.0          run_tids = set(self.train_ids)
  1372        80        246.0      3.1      0.0          for src in self.all_sources:
  1373        79      54499.0    689.9      4.0              key = self[src].files[0].get_one_key(src)
  1374        79      52988.0    670.7      3.9              kd = self[src, key]
  1375        79     556458.0   7043.8     40.9              kd_tids = kd.drop_empty_trains().train_ids
  1376        79        361.0      4.6      0.0              save_pct = len(kd_tids) / n_trains * 100
  1377                                           
  1378        79      62500.0    791.1      4.6              if set(kd_tids) != run_tids and save_pct <= min_saved_pct:
  1379         3        115.0     38.3      0.0                  flaky_sources[best_src_name(src)] = kd_tids
  1380                                           
  1381        79        171.0      2.2      0.0              if display_progress:
  1382        79      64122.0    811.7      4.7                  progress_bar.value += 1
  1383                                           
  1384                                                   # Hide the progress bar now that we've checked all the sources
  1385         1          2.0      2.0      0.0          if display_progress:
  1386         1        473.0    473.0      0.0              progress_bar.close()
  1387                                           
  1388                                                   # Sort the flaky sources by decreasing order of how many trains they're missing
  1389         2         11.0      5.5      0.0          flaky_sources = dict(sorted(flaky_sources.items(), key=lambda x: len(x[1]),
  1390         1          2.0      2.0      0.0                                      reverse=True)
  1391                                                                        )
  1392                                           
  1393                                                   # Plot missing data
  1394         1          6.0      6.0      0.0          import matplotlib.pyplot as plt
  1395         1      45521.0  45521.0      3.3          fig, ax = plt.subplots(figsize=(9, max(2, len(flaky_sources) / 4)))
  1396                                           
  1397         1          2.0      2.0      0.0          bar_height = 0.5
  1398         4         14.0      3.5      0.0          for i, src in enumerate(flaky_sources):
  1399                                                       # First find all the trains that are missing
  1400         3        148.0     49.3      0.0              save_line = np.zeros(n_trains).astype(bool)
  1401         3       4005.0   1335.0      0.3              save_line[np.intersect1d(self.train_ids, flaky_sources[src], return_indices=True)[1]] = True
  1402                                           
  1403                                                       # Loop over each train to find blocks of trains that are either
  1404                                                       # present or missing.
  1405         3         80.0     26.7      0.0              bars = { }
  1406         3          6.0      2.0      0.0              block_start = 0
  1407     36582      64400.0      1.8      4.7              for idx in range(n_trains):
  1408     36579     103342.0      2.8      7.6                  if save_line[idx] != save_line[block_start]:
  1409                                                               # If we find a train that doesn't match the save status of
  1410                                                               # the current block, create a new entry in `bars` to record
  1411                                                               # the start index, the block length, and the save status.
  1412      3276       6328.0      1.9      0.5                      bars[(block_start, idx - block_start)] = save_line[block_start]
  1413      3276       5792.0      1.8      0.4                      block_start = idx
  1414                                           
  1415                                                       # Add the last block
  1416         3          7.0      2.3      0.0              bars[(block_start, n_trains - block_start)] = save_line[block_start]
  1417                                           
  1418                                                       # Plot all the blocks
  1419         6     248346.0  41391.0     18.3              ax.broken_barh(bars.keys(),
  1420         3          6.0      2.0      0.0                             (i, bar_height),
  1421         3        440.0    146.7      0.0                             color=["k" if x else "r" for x in bars.values()])
  1422                                           
  1423                                                   # Set labels and ticks
  1424         2         16.0      8.0      0.0          tick_labels = [f"{src} ({len(tids) / n_trains * 100:.2f}%)"
  1425         1          3.0      3.0      0.0                    for i, (src, tids) in enumerate(flaky_sources.items())]
  1426         2       5428.0   2714.0      0.4          ax.set_yticks(np.arange(len(flaky_sources)) + bar_height / 2,
  1427         1          2.0      2.0      0.0                        labels=tick_labels, fontsize=8)
  1428         1        115.0    115.0      0.0          ax.set_xlabel("Train ID index")
  1429                                           
  1430                                                   # Set title
  1431         1          4.0      4.0      0.0          title = f"Sources with less than {min_saved_pct}% of trains saved"
  1432         1         11.0     11.0      0.0          run_meta = self.run_metadata()
  1433         1          3.0      3.0      0.0          if "proposalNumber" in run_meta and "runNumber" in run_meta:
  1434                                                       title += f" in p{run_meta['proposalNumber']}, run {run_meta['runNumber']}"
  1435         1        353.0    353.0      0.0          ax.set_title(title)
  1436                                           
  1437         1      66971.0  66971.0      4.9          fig.tight_layout()
  1438                                           
  1439         1          2.0      2.0      0.0          return ax

And it shaves a couple seconds off for a run with 769 files:

p3422, run 159. Total running time: 12.7s, bottleneck is KeyData.drop_empty_trains()
Timer unit: 1e-06 s

Total time: 10.0523 s
File: /home/wrigleyj/src/EXtra-data/extra_data/file_access.py
Function: _read_index at line 330

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   330                                               def _read_index(self, source, group):
   331                                                   """Get first index & count for a source.
   332                                           
   333                                                   This is 'real' reading when the requested index is not in the cache.
   334                                                   """
   335       831        957.0      1.2      0.0          ntrains = len(self.train_ids)
   336       831    9451006.0  11373.1     94.0          ix_group = self.file['/INDEX/{}/{}'.format(source, group)]
   337       831     298174.0    358.8      3.0          firsts = ix_group['first'][:ntrains]
   338       831      29058.0     35.0      0.3          if 'count' in ix_group:
   339       831     272018.0    327.3      2.7              counts = ix_group['count'][:ntrains]
   340                                                   else:
   341                                                       status = ix_group['status'][:ntrains]
   342                                                       counts = np.uint64((ix_group['last'][:ntrains] - firsts + 1) * status)
   343       831       1125.0      1.4      0.0          return firsts, counts

Total time: 0 s
File: /home/wrigleyj/src/EXtra-data/extra_data/file_access.py
Function: get_keys at line 373

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   373                                               def get_keys(self, source):
   374                                                   """Get keys for a given source name
   375                                           
   376                                                   Keys are found by walking the HDF5 file, and cached for reuse.
   377                                                   """
   378                                                   try:
   379                                                       return self._keys_cache[source]
   380                                                   except KeyError:
   381                                                       pass
   382                                           
   383                                                   if source in self.control_sources:
   384                                                       group = '/CONTROL/' + source
   385                                                   elif source in self.instrument_sources:
   386                                                       group = '/INSTRUMENT/' + source
   387                                                   else:
   388                                                       raise SourceNameError(source)
   389                                           
   390                                                   res = set()
   391                                           
   392                                                   def add_key(key, value):
   393                                                       if isinstance(value, h5py.Dataset):
   394                                                           res.add(key.replace('/', '.'))
   395                                           
   396                                                   self.file[group].visititems(add_key)
   397                                                   self._keys_cache[source] = res
   398                                                   return res

Total time: 10.5948 s
File: /home/wrigleyj/src/EXtra-data/extra_data/keydata.py
Function: _find_chunks at line 30

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    30                                               def _find_chunks(self):
    31                                                   """Find contiguous chunks of data for this key, in any order."""
    32        79      41300.0    522.8      0.4          all_tids_arr = np.array(self.train_ids)
    33                                           
    34       910        949.0      1.0      0.0          for file in self.files:
    35       831       1650.0      2.0      0.0              if len(file.train_ids) == 0:
    36                                                           continue
    37                                           
    38       831   10071033.0  12119.2     95.1              firsts, counts = file.get_index(self.source, self._key_group)
    39                                           
    40                                                       # Of trains in this file, which are in selection
    41       831     315618.0    379.8      3.0              include = np.isin(file.train_ids, all_tids_arr)
    42       831       1359.0      1.6      0.0              if not self.inc_suspect_trains:
    43                                                           include &= file.validity_flag
    44                                           
    45                                                       # Assemble contiguous chunks of data from this file
    46      1662     151364.0     91.1      1.4              for _from, _to in contiguous_regions(include):
    47      1662       5014.0      3.0      0.0                  yield DataChunk(
    48       831       3124.0      3.8      0.0                      file, self.hdf5_data_path,
    49       831       1140.0      1.4      0.0                      first=firsts[_from],
    50       831       1453.0      1.7      0.0                      train_ids=file.train_ids[_from:_to],
    51       831        773.0      0.9      0.0                      counts=counts[_from:_to],
    52                                                           )

Total time: 11.1658 s
File: /home/wrigleyj/src/EXtra-data/extra_data/keydata.py
Function: drop_empty_trains at line 146

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   146                                               def drop_empty_trains(self):
   147                                                   """Select only trains with data as a new :class:`KeyData` object."""
   148        79   10641516.0 134702.7     95.3          counts = self.data_counts()
   149                                                   # Note: we do this strange UInt64Index -> ndarray -> list conversion
   150                                                   # because for some reason panda's .to_list() method generates
   151                                                   # sub-optimal lists. If .to_list() is used, then operations on the
   152                                                   # train_id's list of the new KeyData object are significantly slower
   153                                                   # (even the call to _only_tids() below).
   154        79      89195.0   1129.1      0.8          tids = list(counts[counts > 0].index.to_numpy())
   155                                           
   156        79     435088.0   5507.4      3.9          return self._only_tids(list(tids))

Total time: 10.6406 s
File: /home/wrigleyj/src/EXtra-data/extra_data/keydata.py
Function: data_counts at line 183

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   183                                               def data_counts(self, labelled=True):
   184                                                   """Get a count of data entries in each train.
   185                                           
   186                                                   If *labelled* is True, returns a pandas series with an index of train
   187                                                   IDs. Otherwise, returns a NumPy array of counts to match ``.train_ids``.
   188                                                   """
   189        79   10607080.0 134266.8     99.7          if self._data_chunks:
   190        79       2262.0     28.6      0.0              train_ids = np.concatenate([c.train_ids for c in self._data_chunks])
   191        79       1620.0     20.5      0.0              counts = np.concatenate([c.counts for c in self._data_chunks])
   192                                                   else:
   193                                                       train_ids = counts = np.zeros(0, dtype=np.uint64)
   194                                           
   195        79         79.0      1.0      0.0          if labelled:
   196        79        187.0      2.4      0.0              import pandas as pd
   197        79      29357.0    371.6      0.3              return pd.Series(counts, index=train_ids)
   198                                                   else:
   199                                                       all_tids_arr = np.array(self.train_ids)
   200                                                       res = np.zeros(len(all_tids_arr), dtype=np.uint64)
   201                                                       tid_to_ix = np.intersect1d(all_tids_arr, train_ids, return_indices=True)[1]
   202                                           
   203                                                       # We may be missing some train IDs, if they're not in any file
   204                                                       # for this source, and they're sometimes out of order within chunks
   205                                                       # (they shouldn't be, but we try not to fail too badly if they are).
   206                                                       # assert np.isin(train_ids, all_tids_arr).all()
   207                                                       assert len(tid_to_ix) == len(train_ids)
   208                                                       res[tid_to_ix] = counts
   209                                           
   210                                                       return res

Total time: 12.7379 s
File: /home/wrigleyj/src/EXtra-data/extra_data/reader.py
Function: plot_missing_data at line 1333

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
  1333                                               def plot_missing_data(self, min_saved_pct=95):
  1334                                                   """Plot sources that have missing data for some trains.
  1335                                           
  1336                                                   Parameters
  1337                                                   ----------
  1338                                           
  1339                                                   min_saved_pct: int or float, optional
  1340                                                       Only show sources with less than this percentage of trains saved.
  1341                                                   """
  1342         1          6.0      6.0      0.0          n_trains = len(self.train_ids)
  1343                                           
  1344                                                   # Helper function that returns an alias for a source if one is
  1345                                                   # available, and the source name otherwise.
  1346         1          5.0      5.0      0.0          def best_src_name(src):
  1347                                                       for alias, alias_ident in self._aliases.items():
  1348                                                           if isinstance(alias_ident, str) and alias_ident == src:
  1349                                                               return alias
  1350                                           
  1351                                                       return src
  1352                                           
  1353                                                   # If possible, create a progress bar. Loading the train IDs for every
  1354                                                   # source can be very slow (tens of seconds) so it's good to give the
  1355                                                   # user some feedback to know that the method hasn't frozen.
  1356         1          2.0      2.0      0.0          display_progress = False
  1357         1         10.0     10.0      0.0          if self._running_in_notebook():
  1358         1          2.0      2.0      0.0              try:
  1359         1         16.0     16.0      0.0                  from ipywidgets import IntProgress
  1360         1          5.0      5.0      0.0                  from IPython.display import display
  1361                                                       except ImportError:
  1362                                                           pass
  1363                                                       else:
  1364         2      11834.0   5917.0      0.1                  progress_bar = IntProgress(min=0, max=len(self.all_sources),
  1365         1          2.0      2.0      0.0                                             description="Checking:")
  1366         1       3824.0   3824.0      0.0                  display(progress_bar)
  1367         1          2.0      2.0      0.0                  display_progress = True
  1368                                           
  1369                                                   # Find sources with missing data
  1370         1          2.0      2.0      0.0          flaky_sources = { }
  1371         1        532.0    532.0      0.0          run_tids = set(self.train_ids)
  1372        80        256.0      3.2      0.0          for src in self.all_sources:
  1373        79     777193.0   9837.9      6.1              key = self[src].files[0].get_one_key(src)
  1374        79      47595.0    602.5      0.4              kd = self[src, key]
  1375        79   11178431.0 141499.1     87.8              kd_tids = kd.drop_empty_trains().train_ids
  1376        79        394.0      5.0      0.0              save_pct = len(kd_tids) / n_trains * 100
  1377                                           
  1378        79      65971.0    835.1      0.5              if set(kd_tids) != run_tids and save_pct <= min_saved_pct:
  1379         3        128.0     42.7      0.0                  flaky_sources[best_src_name(src)] = kd_tids
  1380                                           
  1381        79        188.0      2.4      0.0              if display_progress:
  1382        79      71498.0    905.0      0.6                  progress_bar.value += 1
  1383                                           
  1384                                                   # Hide the progress bar now that we've checked all the sources
  1385         1          2.0      2.0      0.0          if display_progress:
  1386         1        468.0    468.0      0.0              progress_bar.close()
  1387                                           
  1388                                                   # Sort the flaky sources by decreasing order of how many trains they're missing
  1389         2         10.0      5.0      0.0          flaky_sources = dict(sorted(flaky_sources.items(), key=lambda x: len(x[1]),
  1390         1          2.0      2.0      0.0                                      reverse=True)
  1391                                                                        )
  1392                                           
  1393                                                   # Plot missing data
  1394         1         10.0     10.0      0.0          import matplotlib.pyplot as plt
  1395         1      51996.0  51996.0      0.4          fig, ax = plt.subplots(figsize=(9, max(2, len(flaky_sources) / 4)))
  1396                                           
  1397         1          3.0      3.0      0.0          bar_height = 0.5
  1398         4         14.0      3.5      0.0          for i, src in enumerate(flaky_sources):
  1399                                                       # First find all the trains that are missing
  1400         3        165.0     55.0      0.0              save_line = np.zeros(n_trains).astype(bool)
  1401         3       4249.0   1416.3      0.0              save_line[np.intersect1d(self.train_ids, flaky_sources[src], return_indices=True)[1]] = True
  1402                                           
  1403                                                       # Loop over each train to find blocks of trains that are either
  1404                                                       # present or missing.
  1405         3        117.0     39.0      0.0              bars = { }
  1406         3          8.0      2.7      0.0              block_start = 0
  1407     36558      65072.0      1.8      0.5              for idx in range(n_trains):
  1408     36555     107148.0      2.9      0.8                  if save_line[idx] != save_line[block_start]:
  1409                                                               # If we find a train that doesn't match the save status of
  1410                                                               # the current block, create a new entry in `bars` to record
  1411                                                               # the start index, the block length, and the save status.
  1412      3252       6292.0      1.9      0.0                      bars[(block_start, idx - block_start)] = save_line[block_start]
  1413      3252       7226.0      2.2      0.1                      block_start = idx
  1414                                           
  1415                                                       # Add the last block
  1416         3          7.0      2.3      0.0              bars[(block_start, n_trains - block_start)] = save_line[block_start]
  1417                                           
  1418                                                       # Plot all the blocks
  1419         6     244122.0  40687.0      1.9              ax.broken_barh(bars.keys(),
  1420         3          5.0      1.7      0.0                             (i, bar_height),
  1421         3        434.0    144.7      0.0                             color=["k" if x else "r" for x in bars.values()])
  1422                                           
  1423                                                   # Set labels and ticks
  1424         2         16.0      8.0      0.0          tick_labels = [f"{src} ({len(tids) / n_trains * 100:.2f}%)"
  1425         1          3.0      3.0      0.0                    for i, (src, tids) in enumerate(flaky_sources.items())]
  1426         2       5877.0   2938.5      0.0          ax.set_yticks(np.arange(len(flaky_sources)) + bar_height / 2,
  1427         1          2.0      2.0      0.0                        labels=tick_labels, fontsize=8)
  1428         1        110.0    110.0      0.0          ax.set_xlabel("Train ID index")
  1429                                           
  1430                                                   # Set title
  1431         1          4.0      4.0      0.0          title = f"Sources with less than {min_saved_pct}% of trains saved"
  1432         1       4514.0   4514.0      0.0          run_meta = self.run_metadata()
  1433         1          3.0      3.0      0.0          if "proposalNumber" in run_meta and "runNumber" in run_meta:
  1434         1          9.0      9.0      0.0              title += f" in p{run_meta['proposalNumber']}, run {run_meta['runNumber']}"
  1435         1        406.0    406.0      0.0          ax.set_title(title)
  1436                                           
  1437         1      81697.0  81697.0      0.6          fig.tight_layout()
  1438                                           
  1439         1          3.0      3.0      0.0          return ax

And I think I'll go with a print() statement instead of a progress bar, I prefer that over adding a dependency.

@JamesWrigley JamesWrigley force-pushed the missing_trains branch 2 times, most recently from 25ef861 to d2a7219 Compare May 25, 2023 21:35
@JamesWrigley
Copy link
Member Author

JamesWrigley commented May 25, 2023

Added SourceData.one_key() in 1d9d269, and some optimizations in cf62f50 and 9f20890.

Changes in d2a7219:

  • Replace progress bar with a message printed after 2s.
  • Add support for checking subsections of XTDF sources.
  • Add a secondary sorting by source name.
  • Add a legend.
  • Swap the missing/present colors.

I played around with different colors but felt that red/black was still the best combination. On second thoughts I agree with you @takluyver about black==missing being more intuitive so I swapped the colors, and I added a legend for good measure.

Example plot:
image

Aligning the title and legend is kinda tricky. It still looks a bit off with a zillion entries, but in typical cases it looks ok:
image

extra_data/file_access.py Dismissed Show dismissed Hide dismissed
@@ -139,6 +139,18 @@
for f in self.files:
return f.get_keys(self.source)

def one_key(self):

Check notice

Code scanning / CodeQL

Explicit returns mixed with implicit (fall through) returns Note

Mixing implicit and explicit returns may indicate an error as implicit returns always return None.
@@ -467,6 +467,21 @@ def add_key(key, value):
self._keys_cache[source] = res
return res

def get_one_key(self, source):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We also have two pieces of caching we could use in FileAccess to potentially avoid accessing the file. fa._keys_cache stores complete sets of keys as returned by fa.keys(), and fa._known_keys stores partial sets as used by key in source.

It won't make a difference for plot_missing_data() as you're only looking at each source once, but it's probably worth using, whether we add it in this PR or separately.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, fixed in 029011e.

@takluyver
Copy link
Member

Thanks James, this LGTM.

I wonder about changing the red to a less alarm-y colour, like blue or green. My first guess was black=gap, but other people might guess the opposite. But I don't think this is important - we can always change it later if people express confusion.

@takluyver takluyver added this to the 1.13 milestone May 26, 2023
KeyData.data_counts() can return either a Series or ndarray, and for the ndarray
case there was a lot of manual work being done to make sure that the counts
array could handle missing trains and trains that are out-of-order, which was
relatively slow.

For a source with 85% of trains missing in a run with ~12,000 trains,
KeyData.data_counts(labelled=False) took 5.5ms. List of optimizations and
measurements after being applied:
- Vectorizing everything instead of building a dict: 2.3ms
- Using `assume_unique=True` in np.intersect1d(): 2.1ms
- Using `all_tids_arr` instead of `self.train_ids` (a list): 1.48ms
- Re-use `tid_to_ix` in the assert statement instead of using np.isin: 1.2ms

Decided not to pass `assume_unique=True`, so the final time was 1.45ms, about
~3.8x faster than the original version. Note that returning a Series is about
three orders of magnitude faster at about 122us.
This helps performance since the list doesn't have to be repeatedly converted
into an ndarray (or whatever numpy does) before being used.
Helper function for getting a single key, which can be a lot faster than getting
all of them with SourceData.keys().
@JamesWrigley
Copy link
Member Author

So the original reason I went with red is because it offered more contrast to black compared with blue or green, but I did some experimenting and I think I have a new favourite: deeppink 😁

What do you think?
image

Another alternative is cyan:
image

@JamesWrigley
Copy link
Member Author

JamesWrigley commented May 30, 2023

Changes in 73dcd1e:

  • Changed the 'present' color to deeppink
  • Fixed some layout issues with the plot
New giant plot example:

image

@philsmt
Copy link
Contributor

philsmt commented May 30, 2023

This is really neat, thanks! Your questions mostly got answered already, but for completeness:

The check for missing trains might not be robust, it's basically doing this: set(source_tids) != set(run_tids). But that'll kinda break if train IDs repeat, does that ever happen?

As Thomas already explained, this is a broken file. And a broken DAQ.

The code doesn't check the trains for each source-key pair individually, only the first key for the source. Is that ok? I think it is because if I remember correctly the DAQ drops entire hashes at a time if any data in the hash is missing.

While there's some code in EXtra-data treating each key individually, I fully agree with Thomas that the EXDF file structure doesn't even provide for that concept. Control sources have a single index entry for all keys, instrument sources per index group. Thus, it must be assumed that all keys under such index groups (considering a single, empty one for control sources) have the same number of rows.

I called the method plot_missing_data() because technically it's the data for a source that's missing, not entire trains from the run. But I think everyone calls this problem 'missing trains' so maybe it should be named plot_missing_trains() instead? I don't have a strong opinion about that.

Indeed I somewhat dislike the missing data moniker a bit as well. There are simply slower sources on one hand, and there are intentionally sparse data sources with train-on-demand and in the future after (online or performed offline) data reduction. Maybe it's only my non-native ears, but it implies there should be data. How about plot_train_records or plot_data_records?

I want to verbally raise my eyebrow at reader.py growing in size again, but it doesn't seem useful to move this particular function out - unless we move it together with things like DataCollection.info() into a separate module and call it from there.

LGTM!

@JamesWrigley
Copy link
Member Author

Maybe it's only my non-native ears, but it implies there should be data. How about plot_train_records or plot_data_records?

Hmmm, I'm not entirely opposed to them but to me they don't quite describe what the function is doing: plotting data that isn't present. But I also can't think of anything better ATM.

@tmichela
Copy link
Member

What do you think?

I personally prefer cyan but that's up to you, although deeppink is very flashy on my monitor 🙈

@philsmt
Copy link
Contributor

philsmt commented Jun 1, 2023

Hmmm, I'm not entirely opposed to them but to me they don't quite describe what the function is doing: plotting data that isn't present. But I also can't think of anything better ATM.

Agree, I wasn't a huge fan yet either. Thinking a bit more about trying to express "absence without implying fault", a few words I found to possibly jump off (sorry if I went too far with a thesaurus 🙈)

  • plot_absent_trains
  • plot_data_presence
  • plot_train_occupancy

@takluyver
Copy link
Member

Let's try it with deeppink - we can always change it later.

I kind of like the name as is, even if it does imply the data should be there - I think it will be easy to find and remember this way. Of the alternatives offered, plot_data_presence is maybe the best option, but I would expect people to think of this plot as showing where data is missing, rather than where it's present. The default expectation is that most things we're recording are there for most trains, so it's deviations from that that are interesting.

@philsmt
Copy link
Contributor

philsmt commented Jun 1, 2023

The default expectation is that most things we're recording are there for most trains, so it's deviations from that that are interesting.

Currently I agree, but this is not true with data reduction. We have already run singular prototype experiments with raw data reduction, and I worry about the message when the best way to visualize that is with plot_missing_trains.

@JamesWrigley
Copy link
Member Author

How about plot_data_absence()?

@takluyver
Copy link
Member

We might ask Luca for input on this one.

start = time.time()
for src in self.all_sources:
kds = { }
if src.endswith(":xtdf"):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While looking through the implementation in the context of plotting data counts, I noticed I had missed this earlier.

Any instrument source can have multiple index groups (the first part of the key), and there are already two non-XTDF instances in calibration with this. For that purpose SourceData.data_counts() (and SourceData.drop_missing_trains()) has an optional index_group to distinguish between those.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With xtdf detectors, trains with 0 data should have 0 across all index groups. Does this hold for the other cases? I.e. can we summarise the data presence for a source here, or do we need to split it out into N groups?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It depends on how much behaviour of the current DAQ we want to encode. Purely from an EXDF point of view, any index group is separate, and any instrument source may have multiple index groups with differing counts.

In practical terms probably, but I honestly don't know.

  • A Karabo pipeline can have multiple index groups, but as the DAQ only accepts a single and structurally identical hash for each train, they cannot differ (all index groups are 0 or 1).
  • With XTDF, the protocol itself does not seem to mandate a non-zero pulseCount field in the header and what happens with the image section in this case. We should verify with ITDM that the DAQ does not add any train entry (neither header, detector, image nor trailer) when there's no frame data for a particular train, or whether it depends (detector intentionally sending XTDF data with 0 frames, no XTDF data all, online data reduction with empty pattern, etc).

Naturally the virtual sources we create in corrections are an entirely different matter. At the moment the REMI case always has the same number of entries for its index groups, but that is definitely something data reduction may change.

Would this have a significant performance impact in the current version?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not so concerned about the performance impact, I'm thinking more of how the information is presented, and the complexity of the code. Multiplying the number of bars when the different groups within a source have the same trains missing can make it harder to get a useful overview - see the example James showed in this comment.

We could load all the groups for a source, check if they match, and if so only draw one bar. But that starts to sound like a fair bit of complexity for a corner case that it sounds like we don't currently need.

Luckily the details of the plot aren't part of the API contract, so perhaps we can get away with doing something simple now and adding complexity later if necessary.

@takluyver takluyver merged commit 0e4a054 into master Jun 5, 2023
7 of 9 checks passed
@takluyver takluyver deleted the missing_trains branch June 5, 2023 15:44
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants