# Cosmos DB in Fabric #
# Credit Card Fraud Detection Sample – Part 3: Fraud Detection via Cosmos Change Feed in Spark #
This section demonstrates how fraud can be detected from streaming transaction inputs. In the previous notebook we added pending transactions to the PendingCCTransactions container in our Cosmos DB database. Here we read the change feed and process each pending transaction, checking whether it is fraudulent based on the vector distance between the transaction's embedding and the card's transaction history. If it is fraudulent, we update the CreditCards container to lock the card. If it is not fraudulent, we add it to the CCTransactions container.

### Prerequisites
Before running this notebook, ensure you have:

- A **Cosmos DB artifact** created in Microsoft Fabric.
- Three containers:
    - **CCTransactions** – Stores credit card transaction records. 
        Indexing Policy
        {
        "path": "/embedding",
        "type": "DiskANN",
        "dimensions": 1536,
        "metric": "cosine",
        "quantizationByteSize": 4,
        "indexingSearchListSize": 128,
        "vectorIndexShardKey": ["/card_id"]
        }
        data type: float32
    - **CreditCards** – Stores credit card details.
    - **PendingCCTransactions** - Stores pending credit card transactions.
 
- An **OpenAI endpoint and key** for generating embeddings (placeholders will be used in this sample).
- Installed required Python packages.
- An **Environment** set up with the following libraries added:
    - [azure-cosmos-spark_3-5_2-12-4.41.0.jar](https://repo1.maven.org/maven2/com/azure/cosmos/spark/azure-cosmos-spark_3-5_2-12/4.41.0/azure-cosmos-spark_3-5_2-12-4.41.0.jar)
    - [fabric-cosmos-spark-auth_3-1.1.0.jar](https://repo1.maven.org/maven2/com/azure/cosmos/spark/fabric-cosmos-spark-auth_3/1.1.0/fabric-cosmos-spark-auth_3-1.1.0.jar)
    


In [None]:
# Install the required python modules
%pip install azure-core azure-cosmos
%pip install openai

### Imports and Configuration ###

Set up imports and define configuration values for Cosmos DB and OpenAI. Replace placeholder strings with your actual values when running in your environment.

In [None]:
#Imports and config values
import base64, json
import openai
import os
import uuid
import random
import json
import time
import math
import numpy as np
from datetime import datetime, timezone

from typing import Any, Optional, List, Dict, Tuple

#from azure.cosmos.aio import CosmosClient why aio
from azure.cosmos import CosmosClient, PartitionKey, ThroughputProperties
from azure.core.credentials import TokenCredential, AccessToken


# Cosmos DB connection settings. Replace the '<...>' placeholders before running.
COSMOS_ENDPOINT = '<COSMOS_ENDPOINT>' # The Cosmos DB artifact endpoint from the artifact settings
COSMOS_DATABASE_NAME = '<COSMOS_DATABASE_NAME>' # The Cosmos DB artifact name is the database name
COSMOS_TRANSACTION_CONTAINER_NAME = 'CCTransactions'
COSMOS_PENDING_TRANSACTION_CONTAINER_NAME = 'PendingCCTransactions'
COSMOS_CC_CONTAINER_NAME = 'CreditCards'

# Embedding settings: the API version is published through the environment and
# OPEN_AI_MODEL is the model name passed to the embeddings call.
# NOTE(review): no OpenAI endpoint or key is set in this cell; presumably they
# are supplied through environment variables or an earlier cell - confirm.
os.environ["OPENAI_API_VERSION"] = "2023-05-15"
OPEN_AI_MODEL = "text-embedding-ada-002"


# Spark configs for Cosmos connector using Azure AD (no keys)
# These are the shared options; the change-feed reader adds the container name
# and change-feed mode on top of them.
base_read_cfg = {
    "spark.cosmos.auth.type" : "AccessToken",
    "spark.cosmos.accountEndpoint": COSMOS_ENDPOINT,
    "spark.cosmos.database": COSMOS_DATABASE_NAME,
    # Schema inference is off; the change-feed cell parses `_rawBody` itself.
    "spark.cosmos.read.inferSchema.enabled": "false",
    "spark.cosmos.read.consistencyStrategy" : "LOCAL_COMMITTED",
    "spark.cosmos.auth.aad.audience" : "https://cosmos.azure.com/.default",
    "spark.cosmos.accountDataResolverServiceName" : "com.azure.cosmos.spark.fabric.FabricAccountDataResolver",
    "spark.cosmos.useGatewayMode" : "true",
    # Optional: app name for diagnostics
    "spark.cosmos.applicationName": "pending-to-cc-stream",
}


### Authentication Class ###

Use a custom credential class to authenticate securely with Cosmos DB using Fabric tokens.

In [None]:
## Authentication Class

class FabricTokenCredential(TokenCredential):
    """Token credential that supplies Fabric-issued Cosmos DB tokens to the Azure SDK.

    Every ``get_token`` call requests a token from ``notebookutils`` and reads
    its expiry from the JWT payload. This class performs no caching and no
    retries of its own.
    """

    def get_token(self, *scopes: str, claims: Optional[str] = None, tenant_id: Optional[str] = None,
                  enable_cae: bool = False, **kwargs: Any) -> AccessToken:
        """Return an ``AccessToken`` for the Cosmos DB audience.

        The ``scopes``, ``claims``, ``tenant_id``, ``enable_cae`` and ``kwargs``
        arguments are accepted for interface compatibility but are not used;
        the audience is always ``https://cosmos.azure.com/.default``.

        Raises:
            ValueError: if the token is not dot-separated JWT text, its payload
                cannot be decoded, or it carries no usable ``exp`` claim.
        """
        access_token = notebookutils.credentials.getToken("https://cosmos.azure.com/.default")
        parts = access_token.split(".")
        if len(parts) < 2:
            raise ValueError("Invalid JWT format")

        # JWT segments are base64url without padding; restore it before decoding.
        payload_b64 = parts[1]
        payload_b64 += "=" * (-len(payload_b64) % 4)
        payload_json = base64.urlsafe_b64decode(payload_b64.encode("utf-8")).decode("utf-8")
        payload = json.loads(payload_json)

        exp = payload.get("exp")
        if exp is None:
            raise ValueError("exp claim missing in token")
        # `exp` is a JSON number and may arrive as a float (or numeric string);
        # AccessToken.expires_on is an integer epoch timestamp.
        return AccessToken(token=access_token, expires_on=int(exp))

### Initialize Cosmos DB Clients ###

Create clients for the database and containers.

In [None]:
# Initialize the Cosmos DB account client, authenticating with Fabric tokens
COSMOS_CLIENT = CosmosClient(COSMOS_ENDPOINT, FabricTokenCredential())

# Initialize Cosmos DB database client
DATABASE_CLIENT = COSMOS_CLIENT.get_database_client(COSMOS_DATABASE_NAME)

# Initialize the Cosmos DB container clients used throughout this notebook:
# accepted transactions, credit card records, and pending transactions.
txns_container = DATABASE_CLIENT.get_container_client(COSMOS_TRANSACTION_CONTAINER_NAME) 
card_container = DATABASE_CLIENT.get_container_client(COSMOS_CC_CONTAINER_NAME)
pending_txns_container = DATABASE_CLIENT.get_container_client(COSMOS_PENDING_TRANSACTION_CONTAINER_NAME)


In [None]:
# Relative weights applied to each component of the combined transaction
# embedding (amount, merchant, location) before the vector is normalized.
W_AMOUNT   = 0.2
W_MERCHANT = 0.3
W_LOCATION = 0.5

#### Embedding Helper
Wraps a call to the embeddings API and returns a NumPy vector.

In [None]:
from functools import lru_cache
# ─────────────────────────────────────────────
# Embedding helper
# ─────────────────────────────────────────────
def embed_text(text: str) -> np.ndarray:
    """Return the embedding of *text* as a float32 NumPy vector.

    Issues one embeddings request for the model named by ``OPEN_AI_MODEL``
    and uses the first item of the response.
    """
    response = openai.embeddings.create(model=OPEN_AI_MODEL, input=text)
    first_item = response.data[0]
    return np.array(first_item.embedding, dtype=np.float32)


@lru_cache(maxsize=5000)
def embed_text_cached(text: str) -> np.ndarray:
    """Return the embedding of *text*, memoized per distinct string.

    The cache holds up to 5000 entries. Because every caller asking for the
    same text receives the very same array object, the array is marked
    read-only so an in-place edit by one caller cannot corrupt the cached
    value for the others. Derive new arrays (e.g. ``vec * weight``) instead
    of modifying the result in place.
    """
    vec = embed_text(text)
    vec.setflags(write=False)
    return vec

# ─────────────────────────────────────────────
# Combine embedding (amount + merchant + location)
# ─────────────────────────────────────────────

def normalize_amount(amount: float, lo: float, hi: float) -> float:
    """Map *amount* linearly from the band [lo, hi] onto [0.0, 1.0].

    Values outside the band are clipped to the nearest end. The band width
    is floored at 1e-6 so a degenerate band never divides by zero.
    """
    width = max(hi - lo, 1e-6)
    position = (amount - lo) / width
    clipped = np.clip(position, 0.0, 1.0)
    return float(clipped)

def make_embedding(merchant: str, location: str, amount: float, lo: float, hi: float) -> list:
    """Build the combined, unit-length transaction embedding.

    The vector is the concatenation of three weighted parts:
    the amount normalized into the band [lo, hi] (one component, scaled by
    ``W_AMOUNT``), the merchant text embedding (scaled by ``W_MERCHANT``) and
    the location text embedding (scaled by ``W_LOCATION``). The result is
    L2-normalized unless its norm is zero.

    Text embeddings are taken from ``embed_text_cached`` so that repeated
    merchants and locations do not trigger a new embeddings request each time.

    NOTE(review): the output has 1 + len(merchant vector) + len(location
    vector) components; confirm this matches the vector index dimensions
    configured on the transactions container.

    Returns:
        The embedding as a plain list of floats.
    """
    amount_norm = normalize_amount(amount, lo, hi)
    a_vec = np.array([amount_norm], dtype=np.float32) * W_AMOUNT
    # Multiplying allocates a new array, so the cached vectors are never mutated.
    m_vec = embed_text_cached(merchant) * W_MERCHANT
    l_vec = embed_text_cached(location) * W_LOCATION
    combined = np.concatenate([a_vec, m_vec, l_vec]).astype(np.float32)
    norm = np.linalg.norm(combined)
    if norm > 0:
        combined /= norm
    return combined.tolist()


# ─────────────────────────────────────────────
# Generate and insert a transaction
# ─────────────────────────────────────────────
def add_transaction(card_id: str, customer_id: str, merchant: str, location: str,
                    amount: float, lo: float, hi: float):
    """Build and return a transaction document with its embedding.

    The document gets a fresh UUID as ``id`` and the current UTC time as
    ``timestamp``. This function only constructs the dict; it does not write
    it to any container.
    """
    vector = make_embedding(merchant, location, amount, lo, hi)
    txn_id = str(uuid.uuid4())
    created_at = datetime.now(timezone.utc).isoformat()
    return {
        "id": txn_id,
        "type": "transaction",
        "card_id": card_id,
        "customer_id": customer_id,
        "merchant": merchant,
        "location": location,
        "amount": amount,
        "embedding": vector,
        "timestamp": created_at,
    }

### Fraud Detection Logic Overview ###
The is_fraudulent_transaction function implements a vector-based anomaly detection approach scoped to a single credit card. It uses historical transaction embeddings to determine whether a new transaction is suspicious.

**Step-by-Step Process**


- **Provisional Amount Band**

    - Compute an initial (lo, hi) range for the transaction amount using _provisional_band. 
    - This stabilizes normalization before we have enough historical neighbors.



- **Generate Provisional Embedding**

    - Build an embedding for the candidate transaction using merchant, location, amount, and the provisional band. 
    - Convert to list[float] for Cosmos DB vector query compatibility.



- **Fetch Nearest Neighbors**

    - Query the CCTransactions container for TOP-K vector neighbors within the same card partition using VectorDistance.
    - No merchant restriction; purely card-wide similarity.



- **Refine Amount Band**

    - If neighbors exist, compute a robust (lo, hi) band using P5 and P95 percentiles of historical amounts.
    - Fallback to provisional band if no neighbors.



- **Rebuild Embedding**

    - Generate a new embedding using the refined band for better normalization.



- **Prepare Neighbor Embeddings**
    - Collect embeddings from neighbors into a NumPy array for distance calculations.



- **Decision Logic**

    - If fewer than 5 historical transactions exist, return False (not enough data to judge).
    - Compute the centroid of neighbor embeddings and a dynamic threshold:
        
        - Centroid Calculation
            $$
                \text{centroid} = \frac{1}{N} \sum_{i=1}^{N} \mathbf{x}_i
            $$
        - Distance for each neighbor
            $$
                d_i = \|\mathbf{x}_i - \text{centroid}\|_2
            $$
        - Threshold Formula
            $$
                \text{threshold} = \mu_d + (\text{multiplier} \times \sigma_d)
            $$
        - Fraud Decision Rule
            $$
                \|\mathbf{x}_{\text{new}} - \text{centroid}\|_2 > \text{threshold}
            $$
        - Provisional Band
            $$
                \text{span} = 10.0 + 0.25 \cdot \log(1 + a) \cdot a^{0.25}
            $$
            $$
                \text{lo} = \max(0.01, a - \text{span}), \quad \text{hi} = a + \text{span}
            $$
        

    - Calculate the L2 distance of the new embedding to the centroid.
    - **Flag as fraud if**:
       - $$
            \text{new\_dist} > \text{threshold}
         $$


In [None]:
# ─────────────────────────────────────────────────────────────────────────────
# Boolean-only fraud check using card-wide vector neighbors (no merchant filter)
# Returns: True if anomalous/fraud, False otherwise
# ─────────────────────────────────────────────────────────────────────────────


def _provisional_band(amount: float) -> Tuple[float, float]:
    """Return a cold-start (lo, hi) amount band centred on *amount*.

    Used to normalize the amount before any neighbor amounts are known.
    The amount is floored at 1.0; the half-width starts at 10.0 and grows
    slowly with the amount; the lower bound never drops below 0.01.
    """
    anchor = max(1.0, float(amount))
    scaled_log = 0.25 * np.log1p(anchor)
    half_width = 10.0 + scaled_log * anchor ** 0.25
    lower = max(0.01, anchor - half_width)
    upper = anchor + half_width
    return float(lower), float(upper)

def _vector_neighbors_for_card(
    *,
    card_id: str,
    query_embedding: List[float],
    k: int = 100
) -> List:
    """
    Fetch up to ``k`` transactions of one card, ordered by VectorDistance
    between their stored embedding and ``query_embedding``.

    The query runs against the transactions container inside the card's
    partition only and applies no merchant filter. Each result carries the
    embedding, amount and timestamp fields of the matching document.
    """
    sql = """
    SELECT TOP @k c.embedding, c.amount, c.timestamp
    FROM c
    WHERE c.card_id = @card
      AND c.type = 'transaction'
    ORDER BY VectorDistance(c.embedding, @emb)
    """
    bindings = {"@k": int(k), "@card": card_id, "@emb": query_embedding}
    sql_params = [{"name": name, "value": value} for name, value in bindings.items()]
    result_pages = txns_container.query_items(
        query=sql,
        parameters=sql_params,
        partition_key=card_id,
        enable_cross_partition_query=False
    )
    return list(result_pages)

def _robust_amount_band_from_neighbors(neighbors: List[Dict]) -> Optional[Tuple[float, float]]:
    """
    Robust (lo, hi) from neighbor amounts using percentiles (P5, P95).

    Neighbors without an ``amount`` key, or whose amount is ``None`` (a JSON
    null in the stored document), are ignored rather than raising.

    Returns:
        ``(lo, hi)``, or ``None`` when no neighbor supplies a usable amount.
        If the percentiles collapse (hi <= lo), hi is widened to
        ``lo + max(5.0, 10% of max(1.0, lo))`` so the band keeps a positive width.
    """
    if not neighbors:
        return None
    amts = [float(n["amount"]) for n in neighbors if n.get("amount") is not None]
    if not amts:
        return None
    lo = float(np.percentile(amts, 5))
    hi = float(np.percentile(amts, 95))
    if hi <= lo:
        hi = lo + max(5.0, 0.1 * max(1.0, lo))
    return lo, hi

def _centroid_threshold(
    embeddings: np.ndarray,
    multiplier: float = 2.5
) -> Tuple[np.ndarray, float]:
    """
    Return the centroid of the row vectors in ``embeddings`` together with
    an outlier threshold.

    The threshold is mean + multiplier * standard deviation of the rows'
    L2 distances to that centroid.
    """
    center = embeddings.mean(axis=0)
    offsets = embeddings - center
    distances = np.linalg.norm(offsets, axis=1)
    mean_dist = float(np.mean(distances))
    std_dist = float(np.std(distances))
    limit = mean_dist + multiplier * std_dist
    return center, limit

def is_fraudulent_transaction(
    *,
    card_id: str,
    merchant: str,
    location: str,
    amount: float,
    k_neighbors: int = 50,
    multiplier: float = 2.5
) -> bool:
    """
    Returns True if the candidate transaction is anomalous/fraudulent, else False.
    No writes. No merchant restriction. Single vector TOP-K query (partition-scoped).

    Steps:
      1. Embed the candidate using a provisional amount band.
      2. Fetch up to ``k_neighbors`` nearest stored transactions of the card.
      3. With fewer than 5 neighbor embeddings, return False (not enough
         history to judge) before doing any further embedding work.
      4. Re-embed the candidate with the amount band derived from the
         neighbors' amounts (skipped when no such band can be derived).
      5. Flag the candidate when its L2 distance to the neighbors' centroid
         exceeds mean + ``multiplier`` * std of the neighbors' own distances.
    """
    # 1) Provisional band -> stable normalization before neighbors exist
    prov_lo, prov_hi = _provisional_band(amount)

    # 2) Provisional embedding for the vector search (already a list of floats,
    #    which is what the query parameter needs)
    query_emb = make_embedding(
        merchant=merchant, location=location, amount=float(amount), lo=prov_lo, hi=prov_hi
    )

    # 3) Fetch TOP-K neighbors within the card partition
    neighbors = _vector_neighbors_for_card(
        card_id=card_id,
        query_embedding=query_emb,
        k=k_neighbors
    )

    # 4) Collect usable neighbor embeddings and bail out early on thin history,
    #    so cold-start cards do not pay for a second embedding pass.
    vecs = []
    for n in neighbors:
        emb = n.get("embedding")
        if isinstance(emb, list) and emb:
            vecs.append(np.asarray(emb, dtype=np.float32))
    if len(vecs) < 5:
        return False
    hist = np.vstack(vecs)

    # 5) Refine the amount band from the neighbors; without one the band stays
    #    provisional and the embedding from step 2 is reused as-is.
    band = _robust_amount_band_from_neighbors(neighbors)
    if band is None:
        new_emb = np.asarray(query_emb, dtype=np.float32)
    else:
        lo, hi = band
        new_emb = np.asarray(
            make_embedding(merchant=merchant, location=location, amount=float(amount), lo=lo, hi=hi),
            dtype=np.float32
        )

    # 6) Decision: distance to the history centroid versus the dynamic threshold
    centroid, threshold = _centroid_threshold(hist, multiplier=multiplier)
    new_dist = float(np.linalg.norm(new_emb - centroid))
    return new_dist > threshold


#### Quick Fraud Check ####
A quick check to verify that fraud detection works against the Cosmos DB containers.

In [None]:

# Smoke test: score one hand-written transaction for each of two cards and
# print the verdict. Nothing is written to any container here.
# NOTE(review): assumes cards C0001 and C0002 already have transaction
# history from the earlier notebooks; with fewer than 5 stored neighbors the
# check returns False.
is_anom = is_fraudulent_transaction(
    card_id="C0001",
    merchant="Samsung",
    location="Rhode Island",
    amount=180.00,
    k_neighbors=32,
    multiplier=2.5
)
print("Transaction for Customer 1 _is_fraudulent:", is_anom)

is_anom = is_fraudulent_transaction(
    card_id="C0002",
    merchant="Subway",
    location= "California",
    amount= 30.00,
    k_neighbors=32,
    multiplier=2.5)

print("Transaction for Customer 2 _is_fraudulent:", is_anom)

### Change Feed Parsing and Schema Enforcement ###

- **Define Transaction Schema**: The schema must match the fields expected in the pending transaction items.

- **Read Change Feed Stream**: We use the Cosmos DB Spark connector to read the incremental change feed from the COSMOS_PENDING_TRANSACTION_CONTAINER_NAME container.

- **Parse Raw Body**: 
    - Parse _rawBody JSON into structured columns using from_json and txn_schema.
    - Preserve metadata (id, _ts, _lsn, _etag).

In [None]:

from pyspark.sql.functions import col, struct, to_json, from_json
from pyspark.sql import DataFrame
from pyspark.sql.types import StructType, StructField, StringType, DoubleType, BooleanType, TimestampType


# Schema used to parse the raw JSON body of each pending transaction.
# Fields not listed here (for example an embedding) are not carried into the
# parsed rows.
txn_schema = StructType([
    StructField("id", StringType()),                # Document ID
    StructField("type", StringType()),              # e.g., "transaction"
    StructField("card_id", StringType()),
    StructField("customer_id", StringType()),
    StructField("merchant", StringType()),
    StructField("location", StringType()),
    StructField("amount", DoubleType()),
    StructField("timestamp", StringType())
])


# Streaming read of the pending-transactions container's change feed,
# starting from "Now" (only changes made after the stream starts).
changefeed_df: DataFrame = (
    spark.readStream
         .format("cosmos.oltp.changeFeed")
         .options(**{
             **base_read_cfg,
             "spark.cosmos.container": COSMOS_PENDING_TRANSACTION_CONTAINER_NAME,
             "spark.cosmos.changeFeed.mode": "Incremental",
             "spark.cosmos.changeFeed.startFrom": "Now",
         })
         .load()
)

# Filter tombstones if column exists (schema can vary)
cf_cols = set(changefeed_df.columns)
if "_isDelete" in cf_cols:
    changefeed_df = changefeed_df.filter(col("_isDelete") == False)

# Parse `_rawBody` into typed columns and keep the change-feed metadata.
# NOTE(review): `doc.*` already expands to a column named `id`, and the
# top-level `id` is selected again, so the result has two columns called `id`.
# Row-to-dict conversion keeps one of them, but selecting `id` by name on this
# DataFrame would be ambiguous - confirm whether one should be dropped/renamed.
incoming_df = (
    changefeed_df
      .select(
          from_json(col("_rawBody").cast("string"), txn_schema).alias("doc"),
          col("id"),
          col("_ts"),
          col("_lsn"),
          col("_etag"),
      )
      .select("doc.*", "id", "_ts", "_lsn", "_etag")
)

# Ensure essential fields exist
# (fails fast when a required column is absent from the DataFrame)
cf_cols = set(incoming_df.columns)
required_cols = ["id", "card_id", "customer_id", "merchant", "location", "amount"]
for c in required_cols:
    if c not in cf_cols:
        raise ValueError(f"Expected column '{c}' not found in change feed schema: {sorted(cf_cols)}")


### foreachBatch for Fraud Detection and Card Locking ###

This step processes each micro-batch from the Cosmos DB change feed. It performs three key actions:

- **Lock Check**: Before any fraud logic runs, the code queries the CreditCards container to skip transactions for cards already marked as locked. It prints the lock reason and timestamp.
- **Fraud Detection**: For unlocked cards, it calls is_fraudulent_transaction to determine if the transaction is anomalous based on historical embeddings.
- **Conditional Actions**:
    - If fraudulent, Patch the card status to locked and record metadata.
    - If clean, Upsert the transaction into the CCTransactions container.

In [None]:
# Consolidated foreachBatch with your FabricTokenCredential and controlled concurrency
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime, timezone

def _process_batch(batch_df, batch_id: int):
    """Process one change-feed micro-batch of pending transactions.

    For every row, on a worker thread:
      1. Skip rows missing card, merchant, location or customer id.
      2. Skip the row if the card is already locked, or if the lock status
         cannot be read.
      3. Score the transaction with ``is_fraudulent_transaction``.
      4. If fraudulent, patch every matching CreditCards document to
         ``status = "locked"``; otherwise upsert the row into CCTransactions.

    Failures are reported with ``print`` and counted as unsuccessful rows; a
    single bad row never aborts the batch.

    Args:
        batch_df: the micro-batch DataFrame handed over by ``foreachBatch``.
        batch_id: the micro-batch id, used only in log output.
    """
    cards = card_container      # CreditCards: lock lookup and patching
    cc_tx = txns_container      # CCTransactions: destination for clean transactions

    now_iso = datetime.now(timezone.utc).isoformat()
    patch_ops = [
        {"op": "set", "path": "/status",            "value": "locked"},
        {"op": "set", "path": "/last_lock_reason", "value": "locked for fraudulent transaction"},
        {"op": "set", "path": "/last_updated",     "value": now_iso},
    ]

    rows = list(batch_df.toLocalIterator())
    if not rows:
        print(f"[Batch {batch_id}] No rows.")
        return

    def is_throttled(exc: Exception) -> bool:
        """True when *exc* looks like a Cosmos DB 429 (request rate too large)."""
        if getattr(exc, "status_code", None) == 429:
            return True
        msg = str(exc).lower()
        return "429" in msg or "rate is large" in msg

    def call_with_backoff(operation, attempts: int):
        """Run *operation*, retrying throttled calls with exponential backoff.

        Any non-throttling error, and the throttling error of the final
        attempt, is re-raised to the caller.
        """
        delay = 0.5
        for attempt in range(attempts):
            try:
                return operation()
            except Exception as exc:
                if attempt == attempts - 1 or not is_throttled(exc):
                    raise
                time.sleep(delay)
                delay = min(delay * 2, 8.0)

    def handle_row(r) -> tuple[bool, str]:
        """Process one row; returns (success, short status label)."""
        d = r.asDict(recursive=True)
        card_id     = d.get("card_id")
        customer_id = d.get("customer_id")
        merchant    = d.get("merchant")
        location    = d.get("location")

        if not card_id or not merchant or not location:
            return (False, "skip-missing-fields")

        if not customer_id:
            return (False, "skip-missing-customer-id")

        # The column always exists in the row, so a missing/unparseable amount
        # shows up as None rather than as an absent key; treat it as 0.0.
        raw_amount = d.get("amount")
        amount = float(raw_amount) if raw_amount is not None else 0.0

        lock_check_query = """
            SELECT TOP 1 c.id, c.status, c.last_lock_reason, c.last_updated
            FROM c
            WHERE c.customer_id = @cid AND c.card_id = @card
        """
        lock_check_params = [{"name": "@cid", "value": customer_id},
                             {"name": "@card", "value": card_id}]
        try:
            lock_docs = list(cards.query_items(
                query=lock_check_query,
                parameters=lock_check_params,
                partition_key=customer_id,
                enable_cross_partition_query=False
            ))
            if lock_docs:
                card_meta = lock_docs[0]
                if str(card_meta.get("status", "")).lower() == "locked":
                    reason = card_meta.get("last_lock_reason", "(no reason recorded)")
                    when   = card_meta.get("last_updated", "(no timestamp recorded)")
                    print(f"[Batch {batch_id}] Card {card_id} is LOCKED — reason: {reason} — last_updated: {when}. Skipping transaction.")
                    return (False, "skip-locked")
        except Exception as e:
            print(f"[WARN] lock check failed (cust={customer_id}, card={card_id}): {e}")
            # Conservative choice: if the lock status cannot be verified, the
            # transaction is not processed any further.
            return (False, "lock-check-error")

        # Decide
        try:
            is_fraud = is_fraudulent_transaction(
                card_id=card_id,
                merchant=merchant,
                location=location,
                amount=amount,
                k_neighbors=50,          # start smaller for RU efficiency
                multiplier=2.5
            )
        except Exception as e:
            print(f"[WARN] decision error (card={card_id}): {e}")
            return (False, "decision-error")

        if is_fraud:
            # Patch CreditCards inside the customer partition
            print(f"[Batch {batch_id}] Fraud detected for card {card_id} at merchant {merchant}, location {location}")
            query = """
                SELECT c.id
                FROM c
                WHERE c.customer_id = @cid AND c.card_id = @card
            """
            params = [{"name": "@cid", "value": customer_id}, {"name": "@card", "value": card_id}]
            try:
                results = list(cards.query_items(
                    query=query,
                    parameters=params,
                    partition_key=customer_id,
                    enable_cross_partition_query=False
                ))
            except Exception as e:
                print(f"[WARN] patch query failed (cust={customer_id}, card={card_id}): {e}")
                return (False, "patch-error")

            patched = 0
            for doc in results:
                try:
                    # Bind the id now so the callable does not depend on the loop variable.
                    call_with_backoff(
                        lambda doc_id=doc["id"]: cards.patch_item(
                            item=doc_id, partition_key=customer_id, patch_operations=patch_ops
                        ),
                        attempts=6,
                    )
                    patched += 1
                except Exception as ex:
                    print(f"[WARN] patch failed (cust={customer_id}, card={card_id}): {ex}")

            # Only report the card as locked when at least one document was
            # actually patched.
            if patched == 0:
                print(f"[WARN] card {card_id} was NOT locked: no card document could be patched (cust={customer_id})")
                return (False, "patch-error")
            print(f"[Batch {batch_id}] Card {card_id} has been locked due to fraudulent transaction")
            return (True, "patched")

        # Upsert the clean transaction into CCTransactions (partition key = card_id)
        # NOTE(review): the document is the parsed change-feed row plus its
        # metadata columns; it carries no `embedding` field, so it will not
        # contribute a vector to later neighbor searches - confirm whether an
        # embedding should be attached before storing.
        d["customer_id"] = customer_id
        d.pop("partition_key", None)
        try:
            # One initial attempt plus up to six retries on throttling.
            call_with_backoff(lambda: cc_tx.upsert_item(d), attempts=7)
        except Exception as e:
            print(f"[WARN] upsert failed (card={card_id}): {e}")
            return (False, "upsert-error")
        print(f"[Batch {batch_id}] Transaction upserted for card {card_id}")
        return (True, "upserted")

    # I/O-bound → safe to use threads
    max_workers = min(32, max(4, len(rows)))
    ok = 0
    with ThreadPoolExecutor(max_workers=max_workers) as ex:
        futures = [ex.submit(handle_row, r) for r in rows]
        for f in as_completed(futures):
            try:
                success, _ = f.result()
            except Exception as exc:
                # An unexpected error in one row must not fail the whole batch.
                print(f"[WARN] [Batch {batch_id}] row handler raised unexpectedly: {exc}")
                success = False
            ok += 1 if success else 0

    print(f"[Batch {batch_id}] processed={len(rows)} , success={ok}")

In [None]:
# Start the streaming query: every micro-batch of parsed pending transactions
# is handed to _process_batch.
# NOTE(review): no checkpointLocation option is set here; confirm whether
# progress should be persisted across restarts.
consolidated_query = (
    incoming_df   # from your change feed read
      .writeStream
      .foreachBatch(_process_batch)
      .start()
)

In [None]:

# Monitor & control (single consolidated stream)
# Prints the query id, then blocks this cell until the stream terminates.
print("Consolidated query id:", consolidated_query.id)
consolidated_query.awaitTermination()

# Optional helpers:
# consolidated_query.stop()
