## Imports & Dependencies

This cell imports all required libraries and internal helpers used throughout the notebook.
The core logic lives in the `aave_umbrella` package to avoid duplication inside the notebook.

The notebook reuses **the same code paths as the project’s test suite**, with additional
print statements and formatted outputs to make on-chain state transitions
and protocol interactions easier to inspect visually.

In [None]:
from datetime import datetime, timezone
from decimal import Decimal
from typing import Any, Dict, List, Tuple

import matplotlib.pyplot as plt
import pandas as pd

from aave_umbrella.actions.deposit import deposit
from aave_umbrella.actions.harvest import calculate_current_user_rewards, claim_all_rewards
from aave_umbrella.actions.withdraw import redeem
from aave_umbrella.config.addresses import USDC_ADDRESS, USDC_UMBRELLA_STAKE_TOKEN
from aave_umbrella.contracts.batch_helper import BatchHelper, IOData
from aave_umbrella.contracts.erc20 import ERC20
from aave_umbrella.contracts.stake_token import StakeToken
from aave_umbrella.forks.account import get_user_account
from aave_umbrella.forks.funding import fund_user
from aave_umbrella.providers.web3_client import AsyncW3, build_web3_connection
from aave_umbrella.utils.math import amount_to_small_units
from tests.helpers.helper import back_to_the_future

## Helper Functions

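This cell defines utility functions used by the later cells:
- Unit conversion and formatting helpers (base units → human-readable amounts, UNIX timestamps → UTC strings)
- Reading the latest block timestamp on the forked chain (time travel itself uses `back_to_the_future` from the test helpers)
- Common printing and state-inspection helpers (balance and share snapshots)
- Tabulating and plotting accrued rewards with pandas and matplotlib

These helpers never send transactions. At most they read on-chain state (balances, block data) or print and plot results.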

In [None]:
def base_units_to_amount(amount: int, decimals: int) -> Decimal:
    """
    Scale an integer amount of token base units down to whole-token units.

    ``decimals`` is the token's decimal count (this notebook uses 6 for USDC).
    The division is done with ``Decimal`` to avoid binary floating-point error,
    e.g. ``base_units_to_amount(1_500_000, 6) == Decimal("1.5")``.
    """
    one_token = Decimal(10) ** decimals
    return Decimal(amount) / one_token


def fmt_token(amount_base: int, decimals: int, symbol: str) -> str:
    """
    Render a raw token amount as ``"<human> <symbol> (<raw> base units)"``.

    The human-readable value is rounded to two decimals. Both numbers use comma
    thousands separators, e.g. ``fmt_token(1_234_560_000, 6, "USDC")`` gives
    ``"1,234.56 USDC (1,234,560,000 base units)"``.
    """
    raw = int(amount_base)
    readable = base_units_to_amount(raw, decimals)
    return "{:,.2f} {} ({:,} base units)".format(readable, symbol, raw)


def timestamp_to_utc_string(timestamp: int, fmt: str = "%Y-%m-%d %H:%M:%S UTC") -> str:
    """
    Convert a UNIX timestamp (seconds) to a human-readable UTC date string.

    Args:
        timestamp: Seconds since the UNIX epoch (e.g. a block timestamp).
        fmt: ``strftime`` format applied to the timezone-aware UTC datetime.
            The default gives a full date and time with a literal ``UTC``
            suffix. Pass e.g. ``"%Y-%m"`` for a coarser, month-only label.

    Returns:
        The formatted string.

    Example:
        1700000000 -> "2023-11-14 22:13:20 UTC"
    """
    # The previous hard-coded "%Y-%m" contradicted the example above. It also
    # made distinct timestamps in the same calendar month render identically.
    return datetime.fromtimestamp(timestamp, tz=timezone.utc).strftime(fmt)


async def snapshot_balances(
    *,
    web3,
    user_address: str,
    tokens: List[Tuple[str, str, int]],  # (label, token_address, decimals)
) -> Dict[str, int]:
    """
    Print the user's balance of each ERC-20 in ``tokens`` and return them.

    Args:
        web3: Web3 connection handed to each ``ERC20`` wrapper.
        user_address: Wallet whose balances are read.
        tokens: ``(label, token_address, decimals)`` triples. ``label`` is used
            both as the result key and as the printed symbol.

    Returns:
        Mapping of ``label`` -> raw balance in base units.
    """
    print("\n== Snapshot: balances ==")
    print(f"User: {user_address}")

    balances: Dict[str, int] = {}
    for label, token_address, decimals in tokens:
        raw_balance = await ERC20(web3, token_address).balance_of(wallet_address=user_address)
        balances[label] = int(raw_balance)
        print(f"{label}: {fmt_token(raw_balance, decimals, label)}")
    return balances


async def snapshot_shares(
    *,
    web3,
    user_address: str,
    stake_token_address: str,
    label: str,
) -> Dict[str, int]:
    """
    Print and return the user's raw share balance of ``stake_token_address``.

    The balance is read through the generic ``ERC20`` wrapper. ``label`` only
    appears in the printed header.

    Returns:
        ``{"shares": <raw share balance as int>}``.
    """
    token = ERC20(web3, stake_token_address)
    balance = await token.balance_of(wallet_address=user_address)
    print(f"\n== Snapshot: {label} ==")
    shares = int(balance)
    print(f"Umbrella shares (raw): {shares:,}")
    return {"shares": shares}


async def latest_block_timestamp(web3: AsyncW3) -> int:
    """
    Return the timestamp of the most recent block as an ``int``.

    This only needs ``web3.eth.get_block("latest")`` to return a mapping with a
    ``"timestamp"`` entry. Any async Web3-style client exposing that call works.
    """
    latest_block = await web3.eth.get_block("latest")
    timestamp = latest_block["timestamp"]
    return int(timestamp)


def rewards_to_df(
    rows: List[Dict[str, Any]],
    *,
    token_decimals: Dict[str, int] | None = None,
) -> pd.DataFrame:
    """
    Rows items expected keys: month, timestamp, token, amount_base_units
    """
    df = pd.DataFrame(rows)
    if df.empty:
        return df

    if token_decimals:
        df["amount"] = df.apply(
            lambda r: float(base_units_to_amount(r["amount_base_units"], token_decimals.get(r["token"], 18))),
            axis=1,
        )
    return df


def plot_rewards(df: pd.DataFrame, title: str = "Rewards accrued over time") -> None:
    """
    Draw one line per reward token showing its accrued amount over time.

    Expects columns ``timestamp``, ``token`` and ``amount_base_units``. When an
    ``amount`` column (human-readable units) is present it is plotted instead
    of the raw base units. Prints a notice and draws nothing for an empty frame.
    """
    if df.empty:
        print("No data to plot.")
        return

    value_column = "amount_base_units"
    if "amount" in df.columns:
        value_column = "amount"

    _, axes = plt.subplots()
    for token_name, token_rows in df.groupby("token"):
        ordered = token_rows.sort_values("timestamp")
        axes.plot(ordered["timestamp"], ordered[value_column], label=token_name)

    axes.set_title(title)
    axes.set_xlabel("timestamp")
    axes.set_ylabel(value_column)
    axes.legend()
    plt.show()

## Web3 Connection & Global Configuration

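This cell:
- Initializes the `AsyncWeb3` instance and loads the test user account
- Defines all constant addresses and configuration values used in subsequent steps
- Instantiates the contract wrappers (ERC-20, StakeToken, Batch Helper)

All contracts instantiated later reuse this single Web3 connection.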

In [None]:
# Single shared Web3 connection (notebook mode) plus the test account that
# every later cell acts as.
web3 = await build_web3_connection(is_notebook=True)
user_account = await get_user_account(web3)

# Funding source for the funding cell: the address passed to `fund_user`
# as the whale, and the amount it should move (whole tokens, per the funding
# cell's comment -- TODO confirm against fund_user).
whale_address = "0x2d4fbc5ee56f063d33e9c6390265eeac97afcda8"
edge_token_to_fund_amount = 200_000

# Deposit
stake_token_address = USDC_UMBRELLA_STAKE_TOKEN  # Umbrella StakeToken
edge_token_address = USDC_ADDRESS  # Token to deposit
edge_token_decimals = 6  # hard-coded here; the withdraw cell re-reads decimals() on-chain
edge_token_symbol = "USDC"

# Generic ERC-20 wrappers, used for balance reads of the stake-token shares
# and of USDC.
stake_token_contract = ERC20(web3, stake_token_address)
edge_token_contract = ERC20(web3, edge_token_address)

deposit_amount = 10_000  # Amount to deposit (here 10,000 USDC)

# Harvest
stake_token_checksum = web3.to_checksum_address(stake_token_address)
# NOTE(review): assumes `user_account.address` is already checksummed -- it is
# not passed through to_checksum_address here.
user_checksum = user_account.address

# Withdraw
# Same stake-token address as `stake_token_contract`, wrapped as StakeToken
# for the cooldown calls (get_staker_cooldown / cooldown) in the withdraw cell.
stake_token_vault_contract = StakeToken(web3, stake_token_address)
# NOTE(review): not referenced by any of the cells shown in this notebook.
batch_helper_contract = BatchHelper(web3)


print("Test user address:", user_account.address)

## Funding Test Account

This step funds the test user with USDC using whale impersonation on the forked mainnet.
This simulates a realistic user balance without relying on faucets or mocks.

In [None]:
# USDC balance of the test user before funding (raw base units).
current_balance = await edge_token_contract.balance_of(user_account.address)
print("\n== Funding: Before ==")
print(f"{edge_token_symbol}:", fmt_token(current_balance, edge_token_decimals, edge_token_symbol))

# token address -> (source wallet, amount) mapping consumed by fund_user.
token_to_wallets = {
    USDC_ADDRESS: (
        whale_address,  # whale: the wallet the USDC is taken from
        edge_token_to_fund_amount,  # presumably whole tokens (human units) as fund_user expects -- TODO confirm
    ),
}

print("\n== Funding: Executing ==")
success = await fund_user(web3, user_account, token_to_wallets)
assert success is True

# Re-read the balance to show the effect of the funding.
final_balance = await edge_token_contract.balance_of(user_account.address)
print("\n== Funding: After ==")
print(f"{edge_token_symbol}:", fmt_token(final_balance, edge_token_decimals, edge_token_symbol))

## Deposit — USDC → Umbrella StakeToken

This step:
- Approves the Batch Helper contract to spend the user’s USDC
- Deposits USDC into the Umbrella StakeToken (ERC-4626 vault) via the Batch Helper
- Displays USDC balances and minted StakeToken shares before and after the deposit

In [None]:
# Convert the human-readable deposit amount (e.g. 10,000 USDC) to base units.
deposit_amount_base_units = amount_to_small_units(deposit_amount, edge_token_decimals)

# Balances before the deposit: USDC in the wallet and raw stake-token shares.
before_usdc = await edge_token_contract.balance_of(wallet_address=user_account.address)
before_shares = await stake_token_contract.balance_of(wallet_address=user_account.address)

print("\n== Deposit: Before ==")
print(f"Wallet {edge_token_symbol}:", fmt_token(before_usdc, edge_token_decimals, edge_token_symbol))
print("Umbrella shares (raw):", f"{int(before_shares):,}")

print("\n== Deposit: Executing ==")
print("Deposit amount:", fmt_token(deposit_amount_base_units, edge_token_decimals, edge_token_symbol))

# `deposit` receives both token addresses (checksummed) and the amount in
# base units, wrapped in the Batch Helper's IOData.
deposit_status = await deposit(
    web3=web3,
    user_account=user_account,
    params=IOData(
        stake_token=web3.to_checksum_address(stake_token_address),
        edge_token=web3.to_checksum_address(edge_token_address),
        value=int(deposit_amount_base_units),
    ),
)
assert deposit_status is True

after_usdc = await edge_token_contract.balance_of(wallet_address=user_account.address)
after_shares = await stake_token_contract.balance_of(wallet_address=user_account.address)

print("\n== Deposit: After ==")
print(f"Wallet {edge_token_symbol}:", fmt_token(after_usdc, edge_token_decimals, edge_token_symbol))
print("Umbrella shares (raw):", f"{int(after_shares):,}")

# The deposit must have minted new shares to the user.
assert int(after_shares) > int(before_shares)

## Harvest (Read) — Accrued Rewards State

This step:
- Advances blockchain time in approximately monthly (30-day) steps to simulate reward accrual
- Reads on-chain reward state after each step without mutating protocol state
- Displays and plots the pending reward amounts per reward token for the user

This demonstrates how rewards accrue over time on forked mainnet state. Once the block timestamp passes a reward's `distributionEnd`, that reward stops accruing.

In [None]:
current_rewards = await calculate_current_user_rewards(
    web3=web3,
    stake_token=stake_token_checksum,
    user_address=user_checksum,
)

print("\n== Harvest: Baseline rewards ==")
for address, amount in zip(current_rewards[0], current_rewards[1]):
    print(f"Token {address}: {int(amount):,} base units")

# --- Monthly accrual loop
months = 12  # change to 24 for 2 years
rows: List[Dict[str, Any]] = []

start_ts = await latest_block_timestamp(web3)
seconds_per_month = 30 * 24 * 60 * 60  # approx month

print(f"\n== Harvest: Accrual loop ({months} months) ==")
print("Start timestamp:", timestamp_to_utc_string(start_ts))

for m in range(1, months + 1):
    target_ts = start_ts + m * seconds_per_month
    await back_to_the_future(web3, target_ts)

    reward_tokens, reward_amounts = await calculate_current_user_rewards(
        web3=web3,
        stake_token=stake_token_checksum,
        user_address=user_checksum,
    )

    # record per token
    for token, amount in zip(reward_tokens, reward_amounts):
        rows.append(
            {
                "month": m,
                "timestamp": timestamp_to_utc_string(target_ts),
                "token": token,
                "amount_base_units": int(amount),
            }
        )

    # small console summary per month (minimal noise)
    total_small = sum(int(a) for a in reward_amounts)
    # print(f"Month {m:02d}: total rewards (raw sum) = {total_small:,}")

df_rewards = rewards_to_df(rows)
display(df_rewards.head(10))
plot_rewards(df_rewards, title="Rewards accrued (raw base units)")

## Harvest (Claim) — Claim Accrued Rewards

This step:
- Claims accrued rewards via the Umbrella Rewards Controller
- Transfers reward tokens to the user
- Compares each reward token balance after claiming with the last pending amount computed in the accrual loop, then shows the rewards still pending

In [None]:
print("\n== Harvest: Claiming all rewards at end of loop ==")
is_success = await claim_all_rewards(
    web3,
    user_account,
    stake_token=stake_token_checksum,
    receiver=user_checksum,
)
assert is_success is True

# Verify balances received >= last computed amounts for each token (from the latest month)
last_month = df_rewards["month"].max()
last = df_rewards[df_rewards["month"] == last_month]

print("\n== Harvest: Post-claim verification ==")
for token_addr, g in last.groupby("token"):
    expected_amount = int(g["amount_base_units"].max())
    reward_contract = ERC20(web3, token_addr)
    reward_balance = await reward_contract.balance_of(user_checksum)
    print(f"Token reward {token_addr}")
    print(f"Expected at least={expected_amount:,} base units")
    print(f"User current balance={expected_amount:,} base units")

# Ensure rewards reset to 0 after claim
rewards_after_claim = await calculate_current_user_rewards(
    web3=web3,
    stake_token=stake_token_checksum,
    user_address=user_checksum,
)

print("\n== Harvest: Rewards available after claim ==")
for address, amount in zip(rewards_after_claim[0], rewards_after_claim[1]):
    print(f"Token {address}: {int(amount):,} base units")

## Withdraw — Redeem StakeToken for USDC

This step:
- Calls `cooldown()` on the StakeToken to initiate the mandatory cooldown period
- Advances blockchain time to satisfy the cooldown requirement
- Approves the Batch Helper contract to spend the user’s StakeToken shares
- Redeems StakeToken shares back into the underlying USDC via the Batch Helper
- Displays final balances to validate the full lifecycle of a position

In [None]:
shares_balance = await stake_token_contract.balance_of(stake_token_address)

print("\n== Withdraw: Before cooldown ==")
before_cooldown = await stake_token_vault_contract.get_staker_cooldown(user_address=user_account.address)
print("Cooldown info:", before_cooldown)

print("\n== Withdraw: Initiate cooldown ==")
await stake_token_vault_contract.cooldown(user_account)

cooldown_expected_shares, end_of_cooldown, _ = await stake_token_vault_contract.get_staker_cooldown(
    user_address=user_account.address
)
print("Cooldown shares:", f"{int(cooldown_expected_shares):,}")
print("End of cooldown timestamp:", timestamp_to_utc_string(end_of_cooldown))

print("\n== Withdraw: Time travel to end of cooldown ==")
await back_to_the_future(web3, int(end_of_cooldown))

print("\n== Withdraw: Redeem ==")
is_redeem_success = await redeem(
    web3,
    user_account,
    redeem_params=IOData(
        stake_token=web3.to_checksum_address(stake_token_address),
        edge_token=web3.to_checksum_address(edge_token_address),
        value=cooldown_expected_shares,
    ),
)

stake_balance_after = await stake_token_contract.balance_of(wallet_address=user_account.address)

edge_balance_after = await edge_token_contract.balance_of(wallet_address=user_account.address)
edge_decimals = await edge_token_contract.decimals()

print("\n== Withdraw: After redeem ==")
print("Stake token balance:", f"{int(stake_balance_after):,}")
print(f"Wallet {edge_token_symbol}:", fmt_token(edge_balance_after, int(edge_decimals), edge_token_symbol))