Delete DL files on unsubscribe. #16182

Merged
15 commits merged on Sep 6, 2023
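This change makes DataLayer.unsubscribe reconstruct the full-tree and delta filenames for every generation of the store and delete those files from server_files_location after unsubscribing. The test fixtures gain a manage_data_interval parameter so a new test can wait out the manage-data loop, confirm the files exist before unsubscribing, and confirm they are gone afterwards.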
chia/data_layer/data_layer.py (21 changes: 20 additions, 1 deletion)
@@ -35,7 +35,12 @@
 )
 from chia.data_layer.data_layer_wallet import DataLayerWallet, Mirror, SingletonRecord, verify_offer
 from chia.data_layer.data_store import DataStore
-from chia.data_layer.download_data import insert_from_delta_file, write_files_for_root
+from chia.data_layer.download_data import (
+    get_delta_filename,
+    get_full_tree_filename,
+    insert_from_delta_file,
+    write_files_for_root,
+)
 from chia.rpc.rpc_server import StateChangedProtocol, default_get_connections
 from chia.rpc.wallet_rpc_client import WalletRpcClient
 from chia.server.outbound_message import NodeType
@@ -527,10 +532,24 @@ async def unsubscribe(self, tree_id: bytes32) -> None:
         subscriptions = await self.get_subscriptions()
         if tree_id not in (subscription.tree_id for subscription in subscriptions):
             raise RuntimeError("No subscription found for the given tree_id.")
+        filenames: List[str] = []
+        if await self.data_store.tree_id_exists(tree_id):
+            generation = await self.data_store.get_tree_generation(tree_id)
+            all_roots = await self.data_store.get_roots_between(tree_id, 1, generation + 1)
+            for root in all_roots:
+                root_hash = root.node_hash if root.node_hash is not None else self.none_bytes
+                filenames.append(get_full_tree_filename(tree_id, root_hash, root.generation))
+                filenames.append(get_delta_filename(tree_id, root_hash, root.generation))
         async with self.subscription_lock:
             await self.data_store.unsubscribe(tree_id)
         await self.wallet_rpc.dl_stop_tracking(tree_id)
         self.log.info(f"Unsubscribed to {tree_id}")
+        for filename in filenames:
+            file_path = self.server_files_location.joinpath(filename)
+            try:
+                file_path.unlink()
+            except FileNotFoundError:
+                pass
 
     async def get_subscriptions(self) -> List[Subscription]:
         async with self.subscription_lock:
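In short, unsubscribe now walks every generation of the store, rebuilds the two filenames that were written for each root (one full-tree file and one delta file), and removes them from server_files_location, tolerating files that were never written. Below is a minimal, self-contained sketch of that flow; `Root`, `NONE_BYTES`, and the filename strings are illustrative stand-ins, while the real code uses the store's root records, `self.none_bytes`, and `get_full_tree_filename` / `get_delta_filename`.

```python
# Minimal sketch of the cleanup flow added above. `Root` and the filename
# builders are hypothetical stand-ins for the real DataStore records and
# the chia.data_layer.download_data helpers.
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional

NONE_BYTES = bytes(32)  # sentinel used when a root has no node hash


@dataclass
class Root:
    node_hash: Optional[bytes]
    generation: int


def filenames_for_roots(tree_id: bytes, roots: List[Root]) -> List[str]:
    """Rebuild the full-tree and delta filenames for every generation."""
    filenames: List[str] = []
    for root in roots:
        root_hash = root.node_hash if root.node_hash is not None else NONE_BYTES
        # Illustrative naming only; the real layout comes from
        # get_full_tree_filename / get_delta_filename.
        filenames.append(f"{tree_id.hex()}-{root_hash.hex()}-full-{root.generation}.dat")
        filenames.append(f"{tree_id.hex()}-{root_hash.hex()}-delta-{root.generation}.dat")
    return filenames


def delete_tree_files(server_files_location: Path, filenames: List[str]) -> None:
    for filename in filenames:
        try:
            server_files_location.joinpath(filename).unlink()
        except FileNotFoundError:
            pass  # a file may never have been written for this generation
```

On Python 3.8+ the try/except could equivalently be written as `file_path.unlink(missing_ok=True)`; the behavior is the same either way.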
tests/core/data_layer/test_data_rpc.py (51 changes: 50 additions, 1 deletion)
@@ -23,6 +23,7 @@
 from chia.data_layer.data_layer_errors import OfferIntegrityError
 from chia.data_layer.data_layer_util import OfferStore, Status, StoreProofs
 from chia.data_layer.data_layer_wallet import DataLayerWallet, verify_offer
+from chia.data_layer.download_data import get_delta_filename, get_full_tree_filename
 from chia.rpc.data_layer_rpc_api import DataLayerRpcApi
 from chia.rpc.data_layer_rpc_client import DataLayerRpcClient
 from chia.rpc.wallet_rpc_api import WalletRpcApi
@@ -65,6 +66,7 @@ async def init_data_layer_service(
     bt: BlockTools,
     db_path: Path,
     wallet_service: Optional[Service[WalletNode, WalletNodeAPI]] = None,
+    manage_data_interval: int = 5,
 ) -> AsyncIterator[Service[DataLayer, DataLayerAPI]]:
     config = bt.config
     config["data_layer"]["wallet_peer"]["port"] = int(wallet_rpc_port)
@@ -73,6 +75,7 @@
     config["data_layer"]["port"] = 0
     config["data_layer"]["rpc_port"] = 0
     config["data_layer"]["database_path"] = str(db_path.joinpath("db.sqlite"))
+    config["data_layer"]["manage_data_interval"] = manage_data_interval
     save_config(bt.root_path, "config.yaml", config)
     service = create_data_layer_service(
         root_path=bt.root_path, config=config, wallet_service=wallet_service, downloaders=[], uploaders=[]
@@ -91,8 +94,11 @@ async def init_data_layer(
     bt: BlockTools,
     db_path: Path,
     wallet_service: Optional[Service[WalletNode, WalletNodeAPI]] = None,
+    manage_data_interval: int = 5,
 ) -> AsyncIterator[DataLayer]:
-    async with init_data_layer_service(wallet_rpc_port, bt, db_path, wallet_service) as data_layer_service:
+    async with init_data_layer_service(
+        wallet_rpc_port, bt, db_path, wallet_service, manage_data_interval
+    ) as data_layer_service:
         yield data_layer_service._api.data_layer
 
 
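The two fixture changes above simply thread the new manage_data_interval parameter into the service config. A hypothetical caller could shorten the background manage-data loop so that files written for new roots appear quickly enough to assert on:

```python
# Hypothetical usage of the new fixture parameter; the names mirror the
# fixtures above. A 1-second interval makes the manage-data loop (which
# writes full-tree and delta files for new roots) run frequently in tests.
async def example(wallet_rpc_port: uint16, bt: BlockTools, tmp_path: Path) -> None:
    async with init_data_layer(
        wallet_rpc_port=wallet_rpc_port, bt=bt, db_path=tmp_path, manage_data_interval=1
    ) as data_layer:
        ...  # generated files land in data_layer.server_files_location
```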
@@ -2016,3 +2022,46 @@ async def test_issue_15955_deadlock(
     await asyncio.gather(
         *(asyncio.create_task(data_layer.get_value(store_id=tree_id, key=key)) for _ in range(10))
     )
+
+
+@pytest.mark.asyncio
+async def test_unsubscribe_removes_files(
+    self_hostname: str,
+    one_wallet_and_one_simulator_services: SimulatorsAndWalletsServices,
+    tmp_path: Path,
+) -> None:
+    wallet_rpc_api, full_node_api, wallet_rpc_port, ph, bt = await init_wallet_and_node(
+        self_hostname, one_wallet_and_one_simulator_services
+    )
+    manage_data_interval = 5
+    async with init_data_layer(
+        wallet_rpc_port=wallet_rpc_port, bt=bt, db_path=tmp_path, manage_data_interval=manage_data_interval
+    ) as data_layer:
+        data_rpc_api = DataLayerRpcApi(data_layer)
+        res = await data_rpc_api.create_data_store({})
+        root_hashes: List[bytes32] = []
+        assert res is not None
+        store_id = bytes32.from_hexstr(res["id"])
+        await farm_block_check_singelton(data_layer, full_node_api, ph, store_id)
+
+        update_count = 10
+        for batch_count in range(update_count):
+            key = batch_count.to_bytes(2, "big")
+            value = batch_count.to_bytes(2, "big")
+            changelist = [{"action": "insert", "key": key.hex(), "value": value.hex()}]
+            res = await data_rpc_api.batch_update({"id": store_id.hex(), "changelist": changelist})
+            update_tx_rec = res["tx_id"]
+            await farm_block_with_spend(full_node_api, ph, update_tx_rec, wallet_rpc_api)
+            await asyncio.sleep(manage_data_interval * 2)
+            root_hash = await data_rpc_api.get_root({"id": store_id.hex()})
+            root_hashes.append(root_hash["hash"])
+
+        filenames = {path.name for path in data_layer.server_files_location.iterdir()}
+        assert len(filenames) == 2 * update_count
+        for generation, hash in enumerate(root_hashes):
+            assert get_delta_filename(store_id, hash, generation + 1) in filenames
+            assert get_full_tree_filename(store_id, hash, generation + 1) in filenames
+
+        res = await data_rpc_api.unsubscribe(request={"id": store_id.hex()})
+        filenames = {path.name for path in data_layer.server_files_location.iterdir()}
+        assert len(filenames) == 0
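The count assertion follows from each batch update advancing the store by one generation, with every generation persisted as exactly two files, as the short worked example below spells out:

```python
# File inventory arithmetic behind the assertions above: ten batch updates
# produce generations 1..10, and each generation is persisted as one
# full-tree file plus one delta file.
update_count = 10
files_per_generation = 2  # full tree + delta
expected_before_unsubscribe = files_per_generation * update_count  # 20 files
expected_after_unsubscribe = 0  # unsubscribe deletes every generation's files
```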