In [3]:
from __future__ import annotations
from lxml import etree
from datetime import datetime, date
import re
from collections import defaultdict
from typing import Union
import os
import pandas as pd
from rapidfuzz import fuzz, process
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import plotly.express as px
from IPython.display import Image, display, HTML
import plotly.io as pio


# XML namespaces used when walking XBRL instances and linkbases (each defined once).
XBRLI_NS = "http://www.xbrl.org/2003/instance"
XBRLDI_NS = "http://xbrl.org/2006/xbrldi"
XLINK_NS = "http://www.w3.org/1999/xlink"
LINK_NS = "http://www.xbrl.org/2003/linkbase"

# Lower-case substrings that mark a concept qname as revenue-like
# (matched by _is_revenue_concept against the lower-cased qname).
REVENUE_KEYWORDS = [
    "revenue",
    "revenues",
    "sales",
    "netrevenue",
    "netsales",
]

# Standard US-GAAP revenue concepts. Facts tagged with these get priority over
# other revenue-like concepts when several facts exist for the same member.
# (Axis qnames such as us-gaap:StatementBusinessSegmentsAxis are dimensions,
# never fact concepts, so they do not belong here.)
GAAP_REVENUE_TAGS = {
    "us-gaap:Revenues",
    "us-gaap:RevenueFromContractWithCustomerExcludingAssessedTax",
}


In [4]:


def detect_xml_file_type(file_path: str, max_bytes: int = 50 * 1024 * 1024) -> str:
    """
    Detect what kind of XBRL/XML/HTML file this is:
      - 'instance' : contains <xbrli:xbrl>, the XBRL instance namespace URI, or
                     inline XBRL facts (<ix:nonFraction>, <ix:nonNumeric>)
      - 'linkbase' : contains link:definitionLink or link:labelLink
      - 'schema'   : has a .xsd extension, or contains <xsd:schema>
      - 'htm'      : .htm/.html file with none of the markers above
      - 'other'    : unreadable, or none of the above

    A .xsd extension short-circuits to 'schema' without reading the file. Otherwise
    markers are matched case-insensitively against the first ``max_bytes`` bytes
    (default 50 MB, so facts deep inside large inline-XBRL filings are still seen),
    in the order instance -> linkbase -> schema -> htm; the first match wins.
    """
    lower_name = file_path.lower()
    if lower_name.endswith(".xsd"):
        return "schema"

    try:
        with open(file_path, "rb") as f:
            # A single bounded read avoids re-copying a growing bytes buffer per chunk.
            data = f.read(max_bytes).lower()
    except OSError:
        # missing / unreadable / directory: best-effort classification
        return "other"

    instance_markers = (
        b"<xbrli:xbrl",
        b"http://www.xbrl.org/2003/instance",
        b"<ix:nonfraction",
        b"<ix:nonnumeric",
    )
    if any(marker in data for marker in instance_markers):
        return "instance"

    if b"<link:definitionlink" in data or b"<link:labellink" in data:
        return "linkbase"

    if b"<xsd:schema" in data:
        return "schema"
    if lower_name.endswith((".htm", ".html")):
        return "htm"
    return "other"

def _is_segment_axis(qname: str) -> bool:
    """Heuristic: treat segment/product/service axes as segment-like; exclude common noise."""
    if not qname:
        return False
    local = qname.split(":", 1)[-1].lower()

    include = [
        # "statementbusinesssegmentsaxis",  # us-gaap
        "operatingsegmentsaxis",      # us-gaap (older)
        "productorserviceaxis",       # srt
        "productsandservicesaxis",      # us-gaap (alt)
        "businesssegmentaxis",
        "reportablesegmentaxis"
    ]
    exclude = [
        "geographical", "geography", "region", "country", "area",
        "consolidation", "majorcustomer", "customer",
        "concentrationrisk", "benchmark", "typeaxis",
        "range",
    ]
    return any(k in local for k in include) and not any(k in local for k in exclude)

def _frag_to_qname_like(href: str) -> str | None:
    if not href:
        return None
    frag = href.split("#", 1)[1] if "#" in href else href
    if ":" in frag:
        return frag
    if "_" in frag:  # "prefix_local" -> "prefix:local"
        p, l = frag.split("_", 1)
        return f"{p}:{l}"
    return frag

def extract_gaap_members(def_xml_path: str) -> dict[str, list[str]]:
    """Read a definition linkbase and list the top-level members of each segment-like axis.

    Each ``link:definitionLink`` is processed on its own, because locator/arc labels
    are local to it:
      * locators map an xlink:label to a qname (via _frag_to_qname_like);
      * ``dimension-domain`` arcs whose source axis passes _is_segment_axis record
        which domain label belongs to which axis;
      * ``domain-member`` arcs leaving such a domain add their target member.
    Members nested below another member are not followed.

    Returns:
        ``{axis qname: sorted member qnames}``; axes that got no member are absent.
    """
    namespaces = {"link": LINK_NS, "xlink": XLINK_NS}

    def xlink(node, attribute: str):
        return node.get(f"{{{XLINK_NS}}}{attribute}")

    def arcs_with_role(arcs, fragment: str):
        """Yield (from, to) labels of arcs whose lower-cased arcrole contains ``fragment``."""
        for arc in arcs:
            if fragment not in (xlink(arc, "arcrole") or "").lower():
                continue
            source, target = xlink(arc, "from"), xlink(arc, "to")
            if source and target:
                yield source, target

    collected: dict[str, set[str]] = {}
    document = etree.parse(def_xml_path)
    for def_link in document.xpath(".//link:definitionLink", namespaces=namespaces):
        qname_by_label: dict[str, str] = {}
        for locator in def_link.xpath("./link:loc", namespaces=namespaces):
            loc_label, loc_href = xlink(locator, "label"), xlink(locator, "href")
            if loc_label and loc_href:
                resolved = _frag_to_qname_like(loc_href)
                if resolved:
                    qname_by_label[loc_label] = resolved

        arcs = def_link.xpath("./link:definitionArc", namespaces=namespaces)

        # domain label -> axis qname, only for axes that look like segments
        axis_of_domain: dict[str, str] = {}
        for source, target in arcs_with_role(arcs, "dimension-domain"):
            axis = qname_by_label.get(source)
            if axis and _is_segment_axis(axis):
                axis_of_domain[target] = axis

        # only arcs that start directly at a recorded domain contribute members
        for source, target in arcs_with_role(arcs, "domain-member"):
            axis = axis_of_domain.get(source)
            member = qname_by_label.get(target) if axis else None
            if member:
                collected.setdefault(axis, set()).add(member)

    return {axis: sorted(found) for axis, found in collected.items()}

# Lower-case substrings of concepts that mention revenue/sales without *being*
# revenue: costs, deferred/contract balances, backlog (RPO) and receivables.
_NON_REVENUE_MARKERS = (
    "costof",
    "deferred",
    "unearned",
    "contractwithcustomerliability",
    "remainingperformanceobligation",
    "receivable",
)


def _is_revenue_concept(qname: str) -> bool:
    """Return True if a fact concept qname should be treated as revenue.

    Standard GAAP revenue tags always qualify. Otherwise the lower-cased qname must
    contain a REVENUE_KEYWORDS entry and none of _NON_REVENUE_MARKERS, so concepts
    such as us-gaap:CostOfRevenue, us-gaap:DeferredRevenueCurrent or
    us-gaap:RevenueRemainingPerformanceObligation are not mistaken for revenue.
    """
    if qname in GAAP_REVENUE_TAGS:
        return True
    qn_lower = qname.lower()
    if any(marker in qn_lower for marker in _NON_REVENUE_MARKERS):
        return False
    return any(k in qn_lower for k in REVENUE_KEYWORDS)


def _parse_iso_date(s: str | None) -> date | None:
    if not s:
        return None
    t = s.strip()
    if not t:
        return None
    try:
        return datetime.fromisoformat(t[:10]).date()
    except Exception:
        try:
            return datetime.strptime(t[:10], "%Y-%m-%d").date()
        except Exception:
            return None


def _parse_contexts(root) -> dict[str, dict]:
    """Index every ``xbrli:context`` in the document by its ``id``.

    Returns ``{context_id: {"start", "end", "instant", "members"}}`` where the
    period fields are the raw text of startDate/endDate/instant (None when absent)
    and ``members`` lists the non-empty, stripped texts of all
    ``xbrldi:explicitMember`` elements (the member qnames; the dimension attribute
    is not kept). Contexts without an id are skipped.
    """
    def in_xbrli(local: str) -> str:
        return f"{{{XBRLI_NS}}}{local}"

    namespaces = {"xbrli": XBRLI_NS, "xbrldi": XBRLDI_NS}
    result: dict[str, dict] = {}
    for context in root.xpath("//xbrli:context", namespaces=namespaces):
        context_id = context.get("id")
        if not context_id:
            continue

        period_el = context.find(in_xbrli("period"))
        if period_el is None:
            start_text = end_text = instant_text = None
        else:
            start_text, end_text, instant_text = (
                period_el.findtext(in_xbrli(tag)) for tag in ("startDate", "endDate", "instant")
            )

        member_texts = (
            (node.text or "").strip()
            for node in context.xpath(".//xbrldi:explicitMember", namespaces=namespaces)
        )
        result[context_id] = {
            "start": start_text,
            "end": end_text,
            "instant": instant_text,
            "members": [text for text in member_texts if text],
        }
    return result


def _choose_latest_annual_contexts(
    contexts: dict[str, dict],
    cur_date: date,
    max_lag_days: int = 123,
    annual_span_days: tuple[int, int] = (300, 400),
) -> set[str]:
    """Select the duration contexts that represent the fiscal year ending near ``cur_date``.

    Candidates are contexts with a parseable end date no more than ``max_lag_days``
    before ``cur_date`` (later end dates are also kept). Among them, those whose
    start->end span lies within ``annual_span_days`` (inclusive; wide enough for
    52/53-week years) are returned. This prevents a quarter-length context that
    ends at the fiscal-year end from being reported as the annual value. When no
    candidate has an annual-length span, all candidates are returned, as before.

    Returns:
        Set of context ids (empty when there is no candidate).
    """
    candidates: list[tuple[str, int | None]] = []  # (ctx_id, span_days or None)
    for cid, ctx in contexts.items():
        end = _parse_iso_date(ctx.get("end"))
        if not end or (cur_date - end).days > max_lag_days:
            continue
        start = _parse_iso_date(ctx.get("start"))
        candidates.append((cid, (end - start).days if start else None))

    lo, hi = annual_span_days
    annual = {cid for cid, span in candidates if span is not None and lo <= span <= hi}
    return annual or {cid for cid, _span in candidates}

def _choose_latest_quarter_contexts(
    contexts: dict[str, dict], cur_date: date, max_lag_days: int = 31
) -> set[str]:
    """Select the duration contexts that represent the quarter ending near ``cur_date``.

    Only contexts with both a start and an end date are considered, and only those
    ending at most ``max_lag_days`` before ``cur_date`` (later end dates are kept).
    Quarter-like spans (60-100 days) are preferred. Otherwise the contexts with the
    shortest span longer than 30 days are returned, which skips 1-day contexts.

    Returns:
        Set of context ids; an empty *set* (never a dict) when nothing qualifies,
        because callers apply set operations such as ``.union`` to the result.
    """
    spans: list[tuple[str, int]] = []  # (ctx_id, span_days)
    for cid, ctx in contexts.items():
        start = _parse_iso_date(ctx.get("start"))
        end = _parse_iso_date(ctx.get("end"))
        if start and end and (cur_date - end).days <= max_lag_days:
            spans.append((cid, (end - start).days))

    # Prefer true quarter-like spans.
    quarter_like = {cid for cid, span in spans if 60 <= span <= 100}
    if quarter_like:
        return quarter_like

    # Otherwise pick the smallest duration > 30 days (avoid 1-day contexts).
    longer = [(cid, span) for cid, span in spans if span > 30]
    if not longer:
        return set()
    shortest = min(span for _cid, span in longer)
    return {cid for cid, span in longer if span == shortest}


In [14]:
def print_tree(elem, level=0):
    """Debug helper: print an element subtree, one line per tag/attribute/text.

    Each node prints ``<tag>`` indented two spaces per depth level (starting at
    ``level``), then its attributes as ``@name = value`` and its stripped text as
    ``text = ...`` (when non-blank), two spaces deeper. Children follow in
    document order (pre-order traversal).
    """
    pending = [(elem, level)]
    while pending:
        node, depth = pending.pop()
        pad = "  " * depth
        print(f"{pad}<{node.tag}>")
        for name, value in node.attrib.items():
            print(f"{pad}  @{name} = {value}")
        stripped = node.text.strip() if node.text else ""
        if stripped:
            print(f"{pad}  text = {stripped}")
        # push children reversed so the first child is printed next
        pending.extend((child, depth + 1) for child in reversed(list(node)))

def _parse_float(text: str | None) -> float | None:
    if text is None:
        return None
    s = text.strip()
    if not s:
        return None
    s = s.replace(",", "")
    if s.startswith("(") and s.endswith(")"):
        s = "-" + s[1:-1]
    try:
        return float(s)
    except Exception:
        return None




def parse_labels(lab_path):
    """Build ``{concept qname: label text}`` from an XBRL label linkbase.

    Label resources are grouped by their xlink:label. Locators map an xlink:label
    to the concept id taken from the href fragment (after '#', then after the last
    ':'). Each link:labelArc links a concept to a label group. From a group the
    text is chosen by role in the order terseLabel, label, verboseLabel,
    documentation. Within the first role present, an English (``en*``) label is
    preferred, otherwise that role's first label. With no preferred role, the
    group's first label is used.

    Concept ids such as ``us-gaap_Revenues`` become ``us-gaap:Revenues`` (first
    underscore only). When several arcs reach the same concept, the first wins.
    """
    xlink = "{http://www.w3.org/1999/xlink}"
    xml_lang = "{http://www.w3.org/XML/1998/namespace}lang"
    ns = {
        "link": "http://www.xbrl.org/2003/linkbase",
        "xlink": "http://www.w3.org/1999/xlink",
    }
    role_order = (
        "http://www.xbrl.org/2003/role/terseLabel",
        "http://www.xbrl.org/2003/role/label",
        "http://www.xbrl.org/2003/role/verboseLabel",
        "http://www.xbrl.org/2003/role/documentation",
    )
    doc = etree.parse(lab_path)

    # label-resource key -> every non-empty label attached to it
    groups = defaultdict(list)
    for node in doc.findall(".//link:label", ns):
        key = node.get(xlink + "label")
        body = (node.text or "").strip()
        if key and body:
            groups[key].append({
                "text": body,
                "role": node.get(xlink + "role"),
                "lang": node.get(xml_lang, "").lower(),
            })

    # locator key -> concept id
    concept_of = {}
    for node in doc.findall(".//link:loc", ns):
        key = node.get(xlink + "label")
        href = node.get(xlink + "href")
        if key and href and "#" in href:
            concept_of[key] = href.rsplit("#", 1)[-1].rpartition(":")[2]

    def pick(options):
        for role in role_order:
            same_role = [option for option in options if option["role"] == role]
            if same_role:
                return next((o for o in same_role if o["lang"].startswith("en")), same_role[0])
        return options[0]

    result = {}
    for arc in doc.findall(".//link:labelArc", ns):
        concept = concept_of.get(arc.get(xlink + "from"))
        target = arc.get(xlink + "to")
        if concept and target in groups:
            result.setdefault(concept.replace("_", ":", 1), pick(groups[target])["text"])
    return result

In [15]:
_XSI_NIL = "{http://www.w3.org/2001/XMLSchema-instance}nil"


def _select_segment_context_ids(
    contexts: dict[str, dict], K: bool, cur_date: date, instant_window_days: int = 100
) -> set[str]:
    """Return ids of the reporting-period duration contexts plus recent instant contexts."""
    chooser = _choose_latest_annual_contexts if K else _choose_latest_quarter_contexts
    # set() also normalises an empty dict returned by a chooser into an empty set
    selected = set(chooser(contexts, cur_date))

    ids_by_instant: dict[date, set[str]] = defaultdict(set)
    for cid, ctx in contexts.items():
        instant = _parse_iso_date(ctx.get("instant"))
        if instant:
            ids_by_instant[instant].add(cid)
    if ids_by_instant:
        latest_instant = max(ids_by_instant)
        for instant, cids in ids_by_instant.items():
            # instants are anchored to the latest instant in the file, not to cur_date
            if (latest_instant - instant).days <= instant_window_days:
                selected |= cids
    return selected


def _fact_concept_name(el, mode: str) -> str | None:
    """Return the concept qname of a candidate fact element, or None to skip it."""
    if not isinstance(el.tag, str):
        # comments / processing instructions have a non-string tag and are never facts
        return None
    tag = etree.QName(el.tag)
    if tag.namespace == XBRLI_NS:
        return None
    if mode == "xml":
        return f"{el.prefix}:{tag.localname}" if el.prefix else tag.localname
    # inline XBRL: the concept is carried by the `name` attribute of ix:* elements
    return el.get("name") or None


def _fact_value(el, mode: str) -> float | None:
    """Parse a fact's numeric value.

    In inline XBRL the displayed number may sit inside child elements, so all
    descendant text is joined, and ``sign="-"`` marks a negative fact.
    """
    if mode == "htm":
        val = _parse_float("".join(el.itertext()))
        if val is not None and el.get("sign") == "-":
            val = -val
        return val
    return _parse_float(el.text)


def extract_segment_values(instance_path: str, K: bool, cur_date: date, lab_map: dict[str, str] = None, mode = 'xml', axis_to_members: dict[str, list[str]] | None = None) -> dict[str, dict]:
    """Extract revenue facts per segment member from one XBRL instance.

    Args:
        instance_path: XBRL instance (.xml) or inline-XBRL (.htm) document.
        K: True for a 10-K (annual contexts), False for a 10-Q (quarter contexts).
        cur_date: period-end date of the filing; contexts are chosen relative to it.
        lab_map: optional member-qname -> label map used to rename member keys;
            members without a label keep their qname.
        mode: "xml" (concept = element qname) or "htm" (inline XBRL: concept =
            ``name`` attribute, ``scale``/``sign`` attributes honoured).
        axis_to_members: axis qname -> member qnames (from extract_gaap_members).
            When omitted, the notebook-global ``axis_to_members`` (set by the
            driver loop right before each call) is used.

    A revenue fact (see _is_revenue_concept) in a selected context is attributed to
    each explicit member of that context that matches an axis's member list
    (exact qname, or same local name under another prefix; an empty list accepts
    any member). It is filed under the first axis the member ever matched. Per
    member, GAAP_REVENUE_TAGS concepts beat other concepts; at equal priority the
    larger absolute value wins.

    Returns:
        ``{axis: {"unit": unitRef, ["scale": str,] member_or_label: value}}``.
        Values are not scaled. ``scale`` (htm only) is the scale of the last fact
        stored for that axis — NOTE(review): facts of one axis are assumed to
        share a scale.

    Raises:
        ValueError: ``mode`` is neither "xml" nor "htm".
        NameError: ``axis_to_members`` is neither passed nor defined globally.
    """
    if mode not in ("xml", "htm"):
        raise ValueError(f"mode must be 'xml' or 'htm', got {mode!r}")
    if axis_to_members is None:
        try:
            axis_to_members = globals()["axis_to_members"]
        except KeyError:
            raise NameError(
                "axis_to_members was not passed and no global 'axis_to_members' is defined"
            ) from None

    parser = etree.XMLParser(recover=True, huge_tree=True)
    root = etree.parse(instance_path, parser).getroot()

    contexts = _parse_contexts(root)
    selected_ctx_ids = _select_segment_context_ids(contexts, K, cur_date)

    facts: dict[str, dict[str, tuple[int, float]]] = defaultdict(dict)  # axis -> member -> (priority, value)
    member_axis: dict[str, str] = {}  # member -> first axis it matched (its facts are filed there)
    units: dict[str, tuple] = {}  # axis -> (unitRef, scale) of the last fact stored under it

    for el in root.iter():
        qname = _fact_concept_name(el, mode)
        if not qname or not _is_revenue_concept(qname):
            continue

        ctx_id = el.get("contextRef")
        if not ctx_id or ctx_id not in selected_ctx_ids:
            continue

        # skip nil facts
        nil_attr = el.get(_XSI_NIL)
        if nil_attr and nil_attr.strip().lower() in {"true", "1", "yes"}:
            continue

        val = _fact_value(el, mode)
        if val is None:
            continue
        unit_ref = el.get("unitRef")
        scale = el.get("scale") if mode == "htm" else None
        priority = 2 if qname in GAAP_REVENUE_TAGS else 1

        for m in contexts[ctx_id]["members"]:
            for axis, axis_members in axis_to_members.items():
                if axis_members and not any(
                    m == sm or m.endswith(":" + sm.split(":")[-1]) for sm in axis_members
                ):
                    continue
                owner = member_axis.setdefault(m, axis)
                prev = facts[owner].get(m)
                if prev is None or priority > prev[0] or (priority == prev[0] and abs(val) > abs(prev[1])):
                    facts[owner][m] = (priority, val)
                    units[owner] = (unit_ref, scale)

    # project to axis -> {unit, scale?, member -> value}
    res: dict[str, dict] = defaultdict(dict)
    for axis, member_values in facts.items():
        unit_ref, scale = units.get(axis, (None, None))
        res[axis]["unit"] = unit_ref
        if scale is not None:
            res[axis]["scale"] = scale
        for m, (_priority, val) in member_values.items():
            key = lab_map.get(m, m) if lab_map is not None else m
            res[axis][key] = val
    return res

In [17]:

# Driver: walk data_new/<TICKER>/<period folder>/ and extract segment revenue for
# every filing folder holding a definition linkbase (*def.xml), a label linkbase
# (*lab.xml) and an instance document. A folder name ending in "K" marks a 10-K;
# the period-end date comes from the 8-digit date in the folder's .xsd file name.
prefix = "data_new"
meta_data = {}
meta_label = defaultdict(dict)   # ticker -> folder -> {concept qname: label}
meta_member = defaultdict(dict)  # ticker -> folder -> {axis qname: [member qnames]}
for company_name in sorted(os.listdir(prefix)):
    full_path = os.path.join(prefix, company_name)
    if not os.path.isdir(full_path):
        continue  # stray files (e.g. .DS_Store) are not company folders
    print(company_name)
    company_data = {}
    for d in sorted(os.listdir(full_path)):
        filing_dir = os.path.join(full_path, d)
        if not os.path.isdir(filing_dir):
            continue
        K = d.endswith("K")
        date_folder = sorted(os.listdir(filing_dir))
        cur = {}
        cur_date = None
        for f in date_folder:
            if f.endswith("xsd"):
                m = re.search(r"(\d{8})", f)
                if m:
                    cur_date = datetime.strptime(m.group(1), "%Y%m%d").date()
        assert cur_date is not None, f"no .xsd file with an 8-digit date in {filing_dir}"

        instance_candidates = []
        for file in date_folder:
            file_path = os.path.join(filing_dir, file)
            if file.endswith("def.xml"):
                cur["def_xml"] = file_path
            elif file.endswith("lab.xml"):
                cur["lab_xml"] = file_path
            elif detect_xml_file_type(file_path) == "instance":
                instance_candidates.append(file_path)
        if "def_xml" not in cur or "lab_xml" not in cur or not instance_candidates:
            continue

        # A folder may hold both an inline .htm filing and an XML instance. Prefer the
        # XML instance, and derive the parse mode from the file actually used so the
        # mode can never disagree with the chosen instance.
        def _is_inline(path):
            return path.lower().endswith((".htm", ".html"))

        xml_instances = [p for p in instance_candidates if not _is_inline(p)]
        cur["instance_xml"] = (xml_instances or instance_candidates)[-1]
        mode = "htm" if _is_inline(cur["instance_xml"]) else "xml"

        label_mapping = parse_labels(cur["lab_xml"])
        assert label_mapping, f"no labels found in {cur['lab_xml']}"
        meta_label[company_name][d] = label_mapping
        # extract_segment_values() reads this notebook-global mapping
        axis_to_members = extract_gaap_members(cur["def_xml"])
        meta_member[company_name][d] = axis_to_members
        results = extract_segment_values(cur["instance_xml"], K, cur_date, lab_map=None, mode=mode)
        company_data[d] = results

    meta_data[company_name] = company_data

# Merge each company's per-filing label maps. Folder names start with the ISO
# period-end date (derive_quarterly_from_json parses them as %Y-%m-%d), so sorting
# visits filings chronologically and labels from the latest filing win.
label_map = defaultdict(dict)
for company, per_filing in meta_label.items():
    for folder in sorted(per_filing):
        label_map[company].update(per_filing[folder])

AAPL
AMZN
GOOGL
MSFT
NVDA
TSLA


In [26]:
# --------------------------------------------------
# 1. Data Normalization
# --------------------------------------------------
def _axis_scale(axis_dict):
    """Return numeric scale (10**exp) if provided, else 1."""
    s = axis_dict.get("scale", None)
    if s is None:
        return 1
    try:
        exp = int(str(s).strip())
        # treat "6" as exponent meaning 10**6
        return 10 ** exp
    except Exception:
        return 1


def _fiscal_period_labeler(annual_ends: list[date], tolerance_days: int = 14):
    """Build a function mapping a period-end date to ``(fiscal_year, quarter)``.

    The fiscal-year end (month, day) is taken from the latest 10-K period end. A
    date belongs to the first fiscal year whose end, pushed out by
    ``tolerance_days`` to absorb 52/53-week calendars, is on or after it. The
    label is the calendar year in which that fiscal-year end falls, so a
    December year end gives ``d.year``. ``quarter`` counts ~3-month steps since
    the previous fiscal-year end, clamped to 1..4. Without any 10-K, calendar
    years and calendar quarters are used.
    """
    if not annual_ends:
        return lambda d: (d.year, (d.month - 1) // 3 + 1)

    anchor = max(annual_ends)

    def fy_end(year: int) -> date:
        try:
            return date(year, anchor.month, anchor.day)
        except ValueError:  # Feb 29 anchor in a non-leap year
            return date(year, anchor.month, 28)

    def label(d: date) -> tuple[int, int]:
        year = d.year - 1
        while (d - fy_end(year)).days > tolerance_days:
            year += 1
        days_into_year = (d - fy_end(year - 1)).days
        quarter = min(4, max(1, round(days_into_year / 91.3125)))
        return year, quarter

    return label


def derive_quarterly_from_json(data: dict):
    """Flatten one company's extraction results and derive Q4 values, in dollars.

    Args:
        data: ``{period folder name: {axis: {"unit", ["scale"], member: value}}}``.
            Folder names are ``YYYY-MM-DD`` with a trailing ``K`` for 10-K filings.

    Every period gets a fiscal year and quarter from the fiscal-year end of the
    latest 10-K (see _fiscal_period_labeler). Companies whose fiscal year does not
    end in December therefore have their 10-Qs grouped with the 10-K closing the
    same fiscal year, not by calendar year. Values are multiplied by
    ``10 ** scale`` when the axis has a ``scale``. Members whose qname starts
    with ``us-gaap:`` (subtotals) are dropped.

    For each (fiscal year, axis, member) that has a 10-K value, the output holds
    the 10-Q rows (source "10Q") plus a derived Q4 = annual - sum(10-Q values),
    floored at 0 (source "10K_minus_Q1toQ3"). Fiscal years without a 10-K are
    omitted.

    Returns:
        DataFrame with columns fiscal_year, quarter, axis, member, value_usd,
        source, sorted by fiscal_year then quarter. If nothing qualifies, the
        frame is empty but still has these columns.
    """
    columns = ["fiscal_year", "quarter", "axis", "member", "value_usd", "source"]

    periods = []
    for period_key, axes in data.items():
        date_str = period_key.replace("K", "")
        period_end = datetime.strptime(date_str, "%Y-%m-%d").date()
        periods.append((period_key.endswith("K"), date_str, period_end, axes))

    label_period = _fiscal_period_labeler([end for is_annual, _s, end, _a in periods if is_annual])

    records = []
    for is_annual, date_str, period_end, axes in periods:
        fiscal_year, quarter = label_period(period_end)
        for axis_name, axis_vals in axes.items():
            # the unit can be None when the source fact had no unitRef
            unit = (axis_vals.get("unit") or "").lower()
            scale = _axis_scale(axis_vals)

            for member, val in axis_vals.items():
                if member in ("unit", "scale"):
                    continue
                if member.lower().startswith("us-gaap:"):
                    continue  # remove subtotal members
                try:
                    v = float(val) * scale
                except (TypeError, ValueError, OverflowError):
                    continue

                records.append({
                    "period": date_str,
                    "fiscal_year": fiscal_year,
                    "quarter": quarter,
                    "is_annual": is_annual,
                    "axis": axis_name,
                    "member": member,
                    "value_usd": v,
                    "unit": unit or "usd",
                })

    if not records:
        return pd.DataFrame(columns=columns)

    df = pd.DataFrame(records).sort_values(["fiscal_year", "period"])

    # --- derive Q4 from 10-K ---
    out_rows = []
    for (year, axis, member), sub in df.groupby(["fiscal_year", "axis", "member"]):
        annual = sub[sub["is_annual"]]
        if annual.empty:
            continue  # no 10-K for this fiscal year, so Q4 cannot be derived
        quarters = sub[~sub["is_annual"]].sort_values("period")
        annual_val = float(annual["value_usd"].iloc[-1])
        q4_val = max(annual_val - float(quarters["value_usd"].sum()), 0.0)

        for r in quarters.itertuples(index=False):
            out_rows.append({
                "fiscal_year": year,
                "quarter": int(r.quarter),
                "axis": axis,
                "member": member,
                "value_usd": r.value_usd,
                "source": "10Q",
            })
        out_rows.append({
            "fiscal_year": year,
            "quarter": 4,
            "axis": axis,
            "member": member,
            "value_usd": q4_val,
            "source": "10K_minus_Q1toQ3",
        })

    if not out_rows:
        return pd.DataFrame(columns=columns)
    return pd.DataFrame(out_rows).sort_values(["fiscal_year", "quarter"])


# --------------------------------------------------
# 2. Member Name Normalization
# --------------------------------------------------
def normalize_member_names(df, threshold=85):
    """Collapse near-duplicate member names onto one canonical spelling.

    Distinct values of ``df["member"]`` are visited in order of first appearance.
    Each is matched with rapidfuzz ``token_sort_ratio`` against the names already
    chosen as canonical. At or above ``threshold`` it maps to the best canonical
    match; otherwise it becomes canonical itself. Matching only against
    canonical names means a member never maps to a name that was itself remapped.

    Returns:
        A new DataFrame with an added ``normalized_member`` column; ``df`` itself
        is not modified.
    """
    canonical: list[str] = []
    mapping: dict[str, str] = {}
    for name in df["member"].unique().tolist():
        best = process.extractOne(name, canonical, scorer=fuzz.token_sort_ratio) if canonical else None
        if best is not None and best[1] >= threshold:
            mapping[name] = best[0]
        else:
            canonical.append(name)
            mapping[name] = name
    return df.assign(normalized_member=df["member"].map(mapping))


# --------------------------------------------------
# 3. Visualization Functions
# --------------------------------------------------
def _yearly_segment_summary(df, labels):
    """Sum quarterly rows into fiscal-year totals per display name and add YoY %.

    ``labels`` maps member qnames to readable names; members without a label keep
    their ``normalized_member`` value. Adds ``value_mil`` (value_usd / 1e6) and
    ``YoY`` (percent change per display name).
    """
    named = df.assign(display_name=df["normalized_member"].map(lambda m: labels.get(m, m)))
    yearly = named.groupby(["fiscal_year", "display_name"], as_index=False)["value_usd"].sum()
    yearly["value_mil"] = yearly["value_usd"] / 1e6
    # groupby(as_index=False) sorts by fiscal_year first, so pct_change is year over year
    yearly["YoY"] = yearly.groupby("display_name")["value_usd"].pct_change() * 100
    return yearly


def plot_macro_micro_style(df, company="Amazon", key="AMZN", html_path="example/plot.html"):
    """Plot yearly segment revenue as stacked bars (top) and per-segment YoY growth (bottom).

    Args:
        df: quarterly rows as produced by derive_quarterly_from_json (uses the
            ``member``, ``fiscal_year`` and ``value_usd`` columns).
        company: name shown in the chart title.
        key: ticker used to look up display labels in the global ``label_map``.
        html_path: where to save a standalone HTML copy (parent directories are
            created). Pass None or "" to skip saving.

    The figure is rendered once in the notebook as HTML, with plotly.js loaded
    from the CDN.
    """
    df = normalize_member_names(df)
    df = df[~df["member"].str.lower().str.startswith("us-gaap:")]
    yearly = _yearly_segment_summary(df, label_map.get(key) or {})

    members = yearly["display_name"].unique()
    max_year = yearly["fiscal_year"].max()

    fig = make_subplots(rows=2, cols=1, shared_xaxes=True,
                        row_heights=[0.6, 0.4], vertical_spacing=0.05)

    # --- top: stacked bars ---
    for seg in members:
        sub = yearly[yearly["display_name"] == seg]
        fig.add_trace(go.Bar(
            x=sub["fiscal_year"],
            y=sub["value_mil"],
            name=seg,
            hovertemplate="%{y:.1f} M USD<br>Year=%{x}<extra></extra>"
        ), row=1, col=1)

    # --- bottom: YoY growth ---
    for seg in members:
        sub = yearly[yearly["display_name"] == seg]
        fig.add_trace(go.Scatter(
            x=sub["fiscal_year"],
            y=sub["YoY"],
            mode="lines+markers",
            line=dict(dash="dash"),
            name=f"{seg} (YoY%)",
            hovertemplate="%{y:.1f}% YoY<br>Year=%{x}<extra></extra>"
        ), row=2, col=1)

    fig.update_layout(
        title=f"{company} — Revenue Composition & YoY Growth",
        yaxis_title="Revenue (Millions USD)",
        yaxis2_title="YoY Growth (%)",
        legend=dict(orientation="h", y=-0.15),
        template="plotly_white",
        hovermode="x unified",
        updatemenus=[{
            # "{i+1}Y" shows the i+1 fiscal years max_year-i .. max_year
            "buttons": [
                {"args": [{"xaxis.range": [max_year - i, max_year]}],
                 "label": f"{i+1}Y", "method": "relayout"}
                for i in range(1, 6)
            ] + [
                {"args": [{"xaxis.range": [yearly['fiscal_year'].min(),
                                           yearly['fiscal_year'].max()]}],
                 "label": "All", "method": "relayout"}
            ],
            "direction": "right", "x": 0.35, "xanchor": "left", "y": 1.1,
            "yanchor": "top", "type": "buttons",
            "active": 5,  # index of the "All" button
        }]
    )

    if html_path:
        out_dir = os.path.dirname(html_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        # saved before the fixed 900x600 notebook size below is applied
        fig.write_html(html_path)

    fig.update_layout(height=900, width=600)
    # Render exactly once (a fig.show() in addition would draw the chart twice).
    display(HTML(fig.to_html(include_plotlyjs='cdn')))


In [27]:
# Build the AMZN dashboard: per-segment quarterly revenue (Q4 derived from the
# 10-K) aggregated to fiscal years and plotted.
KEY = 'AMZN'
company = KEY
data = meta_data[company]
df_quarters = derive_quarterly_from_json(data)
plot_macro_micro_style(df_quarters, key=KEY)