-
Notifications
You must be signed in to change notification settings - Fork 2k
/
check_wallet_db.py
418 lines (343 loc) · 16 KB
/
check_wallet_db.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
from __future__ import annotations
import asyncio
import sys
from collections import defaultdict
from pathlib import Path
from sqlite3 import Row
from typing import Any, Dict, Iterable, List, Optional, Set
from chia.util.collection import find_duplicates
from chia.util.db_synchronous import db_synchronous_on
from chia.util.db_wrapper import DBWrapper2, execute_fetchone
from chia.util.pprint import print_compact_ranges
from chia.wallet.util.wallet_types import WalletType
# TODO: Check for missing paired wallets (eg. No DID wallet for an NFT)
# TODO: Check for missing DID Wallets
help_text = """
\b
The purpose of this command is find potential issues in Chia wallet databases.
The core chia client currently uses sqlite to store the wallet databases, one database per key.
\b
Guide to warning diagnostics:
----------------------------
"Missing Wallet IDs": A wallet was created and later deleted. By itself, this is okay because
the wallet does not reuse wallet IDs. However, this information may be useful
in conjunction with other information.
\b
Guide to error diagnostics:
--------------------------
Diagnostics in the error section indicate an error in the database structure.
In general, this does not indicate an error in on-chain data, nor does it mean that you have lost coins.
\b
An example is "Missing DerivationPath indexes" - a derivation path is a sub-key of your master key. Missing
derivation paths could cause your wallet to not "know" about transactions that happened on the blockchain.
\b
"""
def _validate_args_addresses_used(wallet_id: int, last_index: int, last_hardened: int, dp: DerivationPath) -> None:
if last_hardened:
if last_hardened != dp.hardened:
raise ValueError(f"Invalid argument: Mix of hardened and unhardened columns wallet_id={wallet_id}")
if last_index:
if last_index != dp.derivation_index:
raise ValueError(f"Invalid argument: noncontiguous derivation_index at {last_index} wallet_id={wallet_id}")
def check_addresses_used_contiguous(derivation_paths: List[DerivationPath]) -> List[str]:
"""
The used column for addresses in the derivation_paths table should be a
zero or greater run of 1's, followed by a zero or greater run of 0's.
There should be no used derivations after seeing a used derivation.
"""
errors: List[str] = []
for wallet_id, dps in dp_by_wallet_id(derivation_paths).items():
saw_unused = False
bad_used_values: Set[int] = set()
ordering_errors: List[str] = []
# last_index = None
# last_hardened = None
for dp in dps:
# _validate_args_addresses_used(wallet_id, last_index, last_hardened, dp)
if saw_unused and dp.used == 1 and ordering_errors == []:
ordering_errors.append(
f"Wallet {dp.wallet_id}: "
f"Used address after unused address at derivation index {dp.derivation_index}"
)
if dp.used == 1:
pass
elif dp.used == 0:
saw_unused = True
else:
bad_used_values.add(dp.used)
# last_hardened = dp.hardened
# last_index = dp.derivation_index
if len(bad_used_values) > 0:
errors.append(f"Wallet {wallet_id}: Bad values in 'used' column: {bad_used_values}")
if ordering_errors != []:
errors.extend(ordering_errors)
return errors
def check_for_gaps(array: List[int], start: int, end: int, *, data_type_plural: str = "Elements") -> List[str]:
"""
Check for compact sequence:
Check that every value from start to end is present in array, and no more.
start and end are values, not indexes
start and end should be included in array
array can be unsorted
"""
if start > end:
raise ValueError(f"{__name__} called with incorrect arguments: start={start} end={end} (start > end)")
errors: List[str] = []
if start == end and len(array) == 1:
return errors
expected_set = set(range(start, end + 1))
actual_set = set(array)
missing = expected_set.difference(actual_set)
extras = actual_set.difference(expected_set)
duplicates = find_duplicates(array)
if len(missing) > 0:
errors.append(f"Missing {data_type_plural}: {print_compact_ranges(list(missing))}")
if len(extras) > 0:
errors.append(f"Unexpected {data_type_plural}: {extras}")
if len(duplicates) > 0:
errors.append(f"Duplicate {data_type_plural}: {duplicates}")
return errors
class FromDB:
def __init__(self, row: Iterable[Any], fields: List[str]) -> None:
self.fields = fields
for field, value in zip(fields, row):
setattr(self, field, value)
def __repr__(self) -> str:
s = ""
for f in self.fields:
s += f"{f}={getattr(self, f)} "
return s
def wallet_type_name(
wallet_type: int,
) -> str:
if wallet_type in {wt.value for wt in WalletType}:
return f"{WalletType(wallet_type).name} ({wallet_type})"
else:
return f"INVALID_WALLET_TYPE ({wallet_type})"
def _cwr(row: Row) -> List[Any]:
r = []
for i, v in enumerate(row):
if i == 2:
r.append(wallet_type_name(v))
else:
r.append(v)
return r
# wallet_types_that_dont_need_derivations: See require_derivation_paths for each wallet type
wallet_types_that_dont_need_derivations = {WalletType.POOLING_WALLET, WalletType.NFT}
class DerivationPath(FromDB):
derivation_index: int
pubkey: str
puzzle_hash: str
wallet_type: WalletType
wallet_id: int
used: int # 1 or 0
hardened: int # 1 or 0
class Wallet(FromDB):
id: int # id >= 1
name: str
wallet_type: WalletType
data: str
def dp_by_wallet_id(derivation_paths: List[DerivationPath]) -> Dict[int, List[DerivationPath]]:
d = defaultdict(list)
for derivation_path in derivation_paths:
d[derivation_path.wallet_id].append(derivation_path)
for k, v in d.items():
d[k] = sorted(v, key=lambda dp: dp.derivation_index)
return d
def derivation_indices_by_wallet_id(derivation_paths: List[DerivationPath]) -> Dict[int, List[int]]:
d = dp_by_wallet_id(derivation_paths)
di = {}
for k, v in d.items():
di[k] = [dp.derivation_index for dp in v]
return di
def print_min_max_derivation_for_wallets(derivation_paths: List[DerivationPath]) -> None:
d = derivation_indices_by_wallet_id(derivation_paths)
print("Min, Max, Count of derivations for each wallet:")
for wallet_id, derivation_index_list in d.items():
# TODO: Fix count by separating hardened and unhardened
print(
f"Wallet ID {wallet_id:2} derivation index min: {derivation_index_list[0]} "
f"max: {derivation_index_list[-1]} count: {len(derivation_index_list)}"
)
class WalletDBReader:
db_wrapper: DBWrapper2 # TODO: Remove db_wrapper member
config = {"db_readers": 1}
sql_log_path: Optional[Path] = None
verbose = False
async def get_all_wallets(self) -> List[Wallet]:
wallet_fields = ["id", "name", "wallet_type", "data"]
async with self.db_wrapper.reader_no_transaction() as reader:
# TODO: if table doesn't exist
cursor = await reader.execute(f"""SELECT {", ".join(wallet_fields)} FROM users_wallets""")
rows = await cursor.fetchall()
return [Wallet(r, wallet_fields) for r in rows]
async def get_derivation_paths(self) -> List[DerivationPath]:
fields = ["derivation_index", "pubkey", "puzzle_hash", "wallet_type", "wallet_id", "used", "hardened"]
async with self.db_wrapper.reader_no_transaction() as reader:
# TODO: if table doesn't exist
cursor = await reader.execute(f"""SELECT {", ".join(fields)} FROM derivation_paths;""")
rows = await cursor.fetchall()
return [DerivationPath(row, fields) for row in rows]
async def show_tables(self) -> List[str]:
async with self.db_wrapper.reader_no_transaction() as reader:
cursor = await reader.execute("""SELECT name FROM sqlite_master WHERE type='table';""")
print("\nWallet DB Tables:")
print(*([r[0] for r in await cursor.fetchall()]), sep=",\n")
print("\nWallet Schema:")
print(*(await (await cursor.execute("PRAGMA table_info('users_wallets')")).fetchall()), sep=",\n")
print("\nDerivationPath Schema:")
print(*(await (await cursor.execute("PRAGMA table_info('derivation_paths')")).fetchall()), sep=",\n")
print()
return []
async def check_wallets(self) -> List[str]:
# id, name, wallet_type, data
# TODO: Move this SQL up a level
async with self.db_wrapper.reader_no_transaction() as reader:
errors = []
try:
main_wallet_id = 1
main_wallet_type = WalletType.STANDARD_WALLET
row = await execute_fetchone(reader, "SELECT * FROM users_wallets WHERE id=?", (main_wallet_id,))
if row is None:
errors.append(f"There is no wallet with ID {main_wallet_id} in table users_wallets")
elif row[2] != main_wallet_type:
errors.append(
f"We expect wallet {main_wallet_id} to have type {wallet_type_name(main_wallet_type)}, "
f"but it has {wallet_type_name(row[2])}"
)
except Exception as e:
errors.append(f"Exception while trying to access wallet {main_wallet_id} from users_wallets: {e}")
max_id: Optional[int] = None
max_id_row = await execute_fetchone(reader, "SELECT MAX(id) FROM users_wallets")
if max_id_row is None:
errors.append("Error fetching max wallet ID from table users_wallets. No wallets ?!?")
else:
cursor = await reader.execute("""SELECT * FROM users_wallets""")
rows = await cursor.fetchall()
max_id = max_id_row[0]
assert max_id is not None
errors.extend(check_for_gaps([r[0] for r in rows], main_wallet_id, max_id, data_type_plural="Wallet IDs"))
if self.verbose:
print("\nWallets:")
print(*[_cwr(r) for r in rows], sep=",\n")
# Check for invalid wallet types in users_wallets
invalid_wallet_types = set()
for row in rows:
if row[2] not in {wt.value for wt in WalletType}:
invalid_wallet_types.add(row[2])
if len(invalid_wallet_types) > 0:
errors.append(f"Invalid Wallet Types found in table users_wallets: {invalid_wallet_types}")
return errors
def check_wallets_missing_derivations(
self, wallets: List[Wallet], derivation_paths: List[DerivationPath]
) -> List[str]:
p = []
d = derivation_indices_by_wallet_id(derivation_paths) # TODO: calc this once, pass in
for w in wallets:
if w.wallet_type not in wallet_types_that_dont_need_derivations and w.id not in d:
p.append(w.id)
if len(p) > 0:
return [f"Wallet IDs with no derivations that require them: {p}"]
return []
def check_derivations_are_compact(self, wallets: List[Wallet], derivation_paths: List[DerivationPath]) -> List[str]:
errors = []
"""
Gaps in derivation index
Missing hardened or unhardened derivations
TODO: Gaps in used derivations
"""
for wallet_id in [w.id for w in wallets]:
for hardened in [0, 1]:
dps = list(filter(lambda x: x.wallet_id == wallet_id and x.hardened == hardened, derivation_paths))
if len(dps) < 1:
continue
dpi = [x.derivation_index for x in dps]
dpi.sort()
max_id = dpi[-1]
h = [" hardened", "unhardened"][hardened]
errors.extend(
check_for_gaps(
dpi, 0, max_id, data_type_plural=f"DerivationPath indexes for {h} wallet_id={wallet_id}"
)
)
return errors
def check_unexpected_derivation_entries(
self, wallets: List[Wallet], derivation_paths: List[DerivationPath]
) -> List[str]:
"""
Check for unexpected derivation path entries
Invalid Wallet Type
Wallet IDs not in table 'users_wallets'
Wallet ID with different wallet_type
"""
errors = []
wallet_id_to_type = {w.id: w.wallet_type for w in wallets}
invalid_wallet_types = []
missing_wallet_ids = []
wrong_type = defaultdict(list)
for d in derivation_paths:
if d.wallet_type not in {wt.value for wt in WalletType}:
invalid_wallet_types.append(d.wallet_type)
if d.wallet_id not in wallet_id_to_type:
missing_wallet_ids.append(d.wallet_id)
elif d.wallet_type != wallet_id_to_type[d.wallet_id]:
wrong_type[(d.hardened, d.wallet_id, d.wallet_type, wallet_id_to_type[d.wallet_id])].append(
d.derivation_index
)
if len(invalid_wallet_types) > 0:
errors.append(f"Invalid wallet_types in derivation_paths table: {invalid_wallet_types}")
if len(missing_wallet_ids) > 0:
errors.append(
f"Wallet IDs found in derivation_paths table, but not in users_wallets table: {missing_wallet_ids}"
)
for k, v in wrong_type.items():
errors.append(
f"""{[" ", "un"][int(k[0])]}hardened Wallet ID {k[1]} uses type {wallet_type_name(k[2])} in """
f"derivation_paths, but type {wallet_type_name(k[3])} in wallet table at these derivation indices: {v}"
)
return errors
async def scan(self, db_path: Path) -> int:
"""Returns number of lines of error output (not warnings)"""
async with DBWrapper2.managed(
database=db_path,
reader_count=self.config.get("db_readers", 4),
log_path=self.sql_log_path,
synchronous=db_synchronous_on("auto"),
) as self.db_wrapper:
# TODO: Pass down db_wrapper
wallets = await self.get_all_wallets()
derivation_paths = await self.get_derivation_paths()
errors = []
warnings = []
if self.verbose:
await self.show_tables()
print_min_max_derivation_for_wallets(derivation_paths)
warnings.extend(await self.check_wallets())
errors.extend(self.check_wallets_missing_derivations(wallets, derivation_paths))
errors.extend(self.check_unexpected_derivation_entries(wallets, derivation_paths))
errors.extend(self.check_derivations_are_compact(wallets, derivation_paths))
errors.extend(check_addresses_used_contiguous(derivation_paths))
if len(warnings) > 0:
print(f" ---- Warnings Found for {db_path.name} ----")
print("\n".join(warnings))
if len(errors) > 0:
print(f" ---- Errors Found for {db_path.name}----")
print("\n".join(errors))
return len(errors)
async def scan(root_path: str, db_path: Optional[str] = None, *, verbose: bool = False) -> None:
if db_path is None:
wallet_db_path = Path(root_path) / "wallet" / "db"
wallet_db_paths = list(wallet_db_path.glob("blockchain_wallet_*.sqlite"))
else:
wallet_db_paths = [Path(db_path)]
num_errors = 0
for wallet_db_path in wallet_db_paths:
w = WalletDBReader()
w.verbose = verbose
print(f"Reading {wallet_db_path}")
num_errors += await w.scan(Path(wallet_db_path))
if num_errors > 0:
sys.exit(2)
if __name__ == "__main__":
loop = asyncio.get_event_loop()
loop.run_until_complete(scan("", sys.argv[1]))