Resolving conflicts after merging with earlier squashed commit
geojunky committed Sep 16, 2022
2 parents 66bfe3a + 810586e commit 1bbc5be
Showing 11 changed files with 1,247 additions and 152 deletions.
93 changes: 51 additions & 42 deletions seismic/ASDFdatabase/_FederatedASDFDataSetImpl.py
@@ -23,20 +23,17 @@
import numpy as np

from obspy.core import Stream, UTCDateTime
from obspy import read, Trace
import pyasdf
from pyasdf import ASDFDataSet
from pyasdf.exceptions import ASDFValueError
import ujson as json
from collections import defaultdict
import sqlite3
import psutil
import hashlib
from functools import partial
from seismic.ASDFdatabase.utils import MIN_DATE, MAX_DATE
import pickle as cPickle
import pandas as pd
from rtree import index
import traceback

logging.basicConfig()

@@ -138,9 +135,9 @@ def __init__(self, asdf_source, logger=None, single_item_read_limit_in_mb=1024):

def _load_corrections(self):
self.correction_files = []
self.correction_map_tree = defaultdict(lambda: defaultdict(list))
self.correction_map_bounds = defaultdict(lambda: defaultdict(list))
self.correction_map_values = defaultdict(lambda: defaultdict(list))
self.correction_map_tree = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
self.correction_map_bounds = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
self.correction_map_values = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))

# check to see if corrections are to be applied
self.corrections_enabled = False
@@ -160,15 +157,17 @@ def _load_corrections(self):

if(len(fnames)): print('Loading corrections..')

dtypes = {'net':object, 'sta':object, 'loc':object, 'comp':object, 'date':object,
'clock_correction':object}
dtypes = {'net':str, 'sta':str, 'loc':str, 'comp':str, 'date':str,
'clock_correction':str}
for fname in fnames:
try:
df = pd.read_csv(fname, delimiter=',', header=0, dtype=dtypes)
df = pd.read_csv(fname, delimiter=',', header=0, dtype=dtypes, na_filter=False)

try:
corr_count = 0
for i in np.arange(len(df)):
net = df['net'][i]
sta = df['sta'][i]
loc = df['loc'][i]
corr = df['clock_correction'][i]

if(corr == 'NOXCOR'): continue
@@ -177,45 +176,50 @@ def _load_corrections(self):
st = UTCDateTime(df['date'][i]).timestamp
et = st + 24*3600

if(type(self.correction_map_tree[net][sta]) != index.Index):
self.correction_map_tree[net][sta] = index.Index()
self.correction_map_bounds[net][sta] = []
self.correction_map_values[net][sta] = []
if(type(self.correction_map_tree[net][sta][loc]) != index.Index):
self.correction_map_tree[net][sta][loc] = index.Index()
self.correction_map_bounds[net][sta][loc] = []
self.correction_map_values[net][sta][loc] = []
# end if

self.correction_map_tree[net][sta].insert(i, (st, 1, et, 1))
self.correction_map_bounds[net][sta].append([st, et])
self.correction_map_values[net][sta].append(corr)
self.correction_map_tree[net][sta][loc].insert(corr_count, (st, 1, et, 1))
self.correction_map_bounds[net][sta][loc].append([st, et])
self.correction_map_values[net][sta][loc].append(corr)
corr_count += 1
# end for
except:
raise ValueError('Failed to read corrections file {}..'.format(fname))
except Exception as e:
print ('Warning: failed to read corrections file {} with error({}). '
'Continuing along..'.format(fname, traceback.format_exc()))
#end try
# end for
#end func
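
The hunk above deepens the correction maps to key on location code as well as network and station, and swaps the raw lists for an rtree interval index per net/sta/loc. Below is a minimal, self-contained sketch of that indexing pattern, with made-up timestamps and correction values: rtree stores boxes, so each daily window [st, et] is inserted as the degenerate rectangle (st, 1, et, 1).

from rtree import index

tindex = index.Index()
bounds, values = [], []

day_start = 1663286400.0                  # 2022-09-16T00:00:00Z, illustrative
day_end = day_start + 24 * 3600

# insert the daily correction window as a degenerate 2-D box
tindex.insert(0, (day_start, 1, day_end, 1))
bounds.append([day_start, day_end])
values.append('-0.125')                   # clock correction (s), kept as str

# query: which correction windows overlap a trace spanning [qst, qet]?
qst, qet = day_start + 3600, day_start + 7200
hits = list(tindex.intersection((qst, 1, qet, 1)))
print([values[i] for i in hits])          # -> ['-0.125']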

def _get_correction(self, net, sta, st, et):
tindex = self.correction_map_tree[net][sta]
def _get_correction(self, net, sta, loc, st, et):
tindex = self.correction_map_tree[net][sta][loc]
if(type(tindex) != index.Index):
return None
else:
epsilon = 1e-5
indices = list(tindex.intersection((st.timestamp+epsilon, 1, et.timestamp-epsilon, 1)))

if(len(indices)):
if(len(indices) > 1):
raise ValueError('Error encountered in _get_correction. Aborting..')
if(len(indices) > 1):
print('Warning: multivalued corrections found ({}.{}.{}: {} - {}). '
'Ignoring and moving along..'.format(net, sta, loc, st, et))
return None
# end if

cst, cet = self.correction_map_bounds[net][sta][indices[0]]
cst, cet = self.correction_map_bounds[net][sta][loc][indices[0]]
a = np.fmax(st.timestamp, cst)
b = np.fmin(et.timestamp, cet)

if(a > b):
raise ValueError('Error encountered in _get_correction. Aborting..')
if(a > b): # sanity check
raise ValueError('Error encountered in _get_correction (({}.{}.{}: {} - {})). '
'Aborting..'.format(net, sta, loc, st, et))
# end if

# return overlap and correction
return [UTCDateTime(a), UTCDateTime(b)], self.correction_map_values[net][sta][indices[0]]
return [UTCDateTime(a), UTCDateTime(b)], self.correction_map_values[net][sta][loc][indices[0]]
else:
return None
# end if
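
A note on the epsilon shrink in the intersection query above: adjacent daily windows share endpoints, and rtree reports boxes that merely touch as intersecting, so an un-shrunk query starting exactly on a day boundary would presumably pick up both days. A small sketch with illustrative timestamps:

from rtree import index

tindex = index.Index()
tindex.insert(0, (0.0, 1, 86400.0, 1))        # day 1
tindex.insert(1, (86400.0, 1, 172800.0, 1))   # day 2

epsilon = 1e-5
qst, qet = 86400.0, 90000.0   # trace starting exactly on the day boundary

# the raw query touches day 1 at a single point -> two hits
print(sorted(tindex.intersection((qst, 1, qet, 1))))                      # [0, 1]
# the shrunk query keeps only the window with genuine overlap
print(sorted(tindex.intersection((qst + epsilon, 1, qet - epsilon, 1))))  # [1]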
@@ -258,6 +262,7 @@ def day_split(trc):
for tr in dayStream:
net = tr.stats.network
sta = tr.stats.station
loc = tr.stats.location
st = tr.stats.starttime
et = tr.stats.endtime

@@ -266,7 +271,7 @@
continue
# end if

result = self._get_correction(net, sta, st, et)
result = self._get_correction(net, sta, loc, st, et)

if(result):
ost, oet = result[0]
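
The remainder of this hunk is collapsed, but the visible lines show `result` carrying the overlap window and its correction. A plausible sketch of what applying it could look like, assuming the correction is a time shift in seconds; this is a hypothetical helper, not the committed code:

def apply_clock_correction(tr, result):
    # Hypothetical helper: `tr` is an obspy Trace and `result` is the
    # ([ost, oet], correction) pair returned by _get_correction above;
    # the correction string is assumed to be a shift in seconds.
    if result is None:
        return tr
    (ost, oet), corr = result
    ctr = tr.slice(ost, oet)            # window covered by the correction
    ctr.stats.starttime += float(corr)  # shift the window's timestamps
    return ctr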
@@ -467,10 +472,8 @@ def get_stations(self, starttime, endtime, network=None, station=None, location=
for row in rows:
ds_id, net, sta, loc, cha, st, et, tag = row

rv = (net, sta, loc, cha,
self.asdf_station_coordinates[ds_id]['%s.%s' % (net, sta)][0],
self.asdf_station_coordinates[ds_id]['%s.%s' % (net, sta)][1],
self.asdf_station_coordinates[ds_id]['%s.%s' % (net, sta)][2])
# [net, sta, loc, cha, lon, lat, elev_m]
rv = (net, sta, loc, cha, *self.asdf_station_coordinates[ds_id]['%s.%s' % (net, sta)])
results.add(rv)
# end for
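
The rewritten tuple above splices the station coordinates with a single * unpacking instead of three indexed lookups; both forms produce the same flat tuple. A quick illustration with made-up coordinates:

coords = {'AU.ARMA': (151.628, -30.420, 977.0)}   # lon, lat, elev_m (illustrative)
net, sta, loc, cha = 'AU', 'ARMA', '', 'BHZ'

rv = (net, sta, loc, cha, *coords['%s.%s' % (net, sta)])
print(rv)   # ('AU', 'ARMA', '', 'BHZ', 151.628, -30.42, 977.0)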

@@ -596,19 +599,25 @@ def stations_iterator(self, network_list=[], station_list=[]):
for sta in stas:
sta = sta[0]

tbounds = self.conn.execute("select st, et from wdb where net='%s' and sta='%s' order by et"
# trace-count, min(st), max(et)
attribs = self.conn.execute("select count(st), min(st), max(et) from wdb where net='%s' and sta='%s'"
%(net, sta)).fetchall()

if(len(tbounds)==0): continue
#tbounds = np.array(tbounds)
#tbounds = split_list_by_timespan(tbounds, self.nproc)
tbounds = split_list(tbounds, self.nproc)
if(len(attribs)==0): continue
tcount, min_st, max_et = np.array(attribs).flatten()

# create start and end times for each rank
r = np.linspace(min_st, max_et, self.nproc + 1)
rank_spans = np.vstack([r[0:-1], r[1:]]).T

# reproducibly shuffle rank-spans to balance load across ranks
np.random.seed(int(tcount))
rank_spans = split_list(rank_spans, self.nproc)
np.random.shuffle(rank_spans)

for iproc in np.arange(self.nproc):
if (len(tbounds[iproc])):
arr = np.array(tbounds[iproc])
workload[iproc][net][sta] = np.array([np.min(arr[:,0]),
np.max(arr[:,1])])
if (len(rank_spans[iproc])):
workload[iproc][net][sta] = rank_spans[iproc].flatten()
# end for
# end for
# end for
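
The reworked loop above stops fetching every trace's time bounds and instead cuts each station's overall time range into nproc equal spans, shuffling them with a seed derived from the trace count so that every rank computes the identical assignment without communication. A condensed sketch of just that arithmetic (split_list is the repository's chunking helper and is omitted here; with nproc spans, each rank ends up with one):

import numpy as np

nproc = 4
tcount, min_st, max_et = 1000, 0.0, 8 * 86400.0   # illustrative: 8 days of data

# start and end times for each rank
r = np.linspace(min_st, max_et, nproc + 1)
rank_spans = np.vstack([r[0:-1], r[1:]]).T        # shape (nproc, 2): [start, end]

# reproducible shuffle: the same seed yields the same order on every rank
np.random.seed(int(tcount))
np.random.shuffle(rank_spans)                     # shuffles rows in place

for iproc in np.arange(nproc):
    print(iproc, rank_spans[iproc].flatten())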