In [323]:
import polars as pl
import itertools
from functools import lru_cache

from pathlib import Path

In [324]:
PROJECTDIR = Path("/home/akash/Main/projects/CACourses")

In [None]:
# functions

def extract_articulations(fp: Path, schema: pl.Schema | None) -> pl.DataFrame:
    uni = int(fp.parts[-2])
    cc  = int(fp.parts[-1].split('to')[0])

    lf = pl.read_json(source=fp, schema=schema).lazy()

    # Normalize structure (Explode list vs Rename single)
    if "prefixes" in str(fp):
        lf = lf.explode("articulations")
    else:
        lf = lf.rename({"articulation": "articulations"})

    return (
        lf
        # 1. Filter empty articulations immediately
        .filter(
            pl.col("articulations")
            .struct.field("sendingArticulation")
            .struct.field("items")
            .list.len() > 0
        )
        # 2. Extract Fields & Merge Source IDs
        .select(
            # Extract Series IDs (List[Int]) and Single IDs (Int)
            series_ids=pl.col("articulations").struct.field("series").struct.field("courses")
                     .list.eval(pl.element().struct.field("courseIdentifierParentId")),
            
            root_id=pl.col("articulations").struct.field("course").struct.field("courseIdentifierParentId"),

            # Extract Destination Data
            sending_items=pl.col("articulations").struct.field("sendingArticulation").struct.field("items"),
            
            # Global Conjunction
            global_conj=(
                pl.col("articulations")
                .struct.field("sendingArticulation")
                .struct.field("courseGroupConjunctions")
                .list.first()
                .struct.field("groupConjunction")
                .fill_null("Or")
            )
        )
        # 3. Safe Explode Logic
        # Coalesce series list with root_id (wrapped in a list) so we never drop rows
        .with_columns(
            source_id_list=pl.coalesce(
                pl.col("series_ids"), 
                pl.concat_list(pl.col("root_id")) 
            )
        )
        .explode("source_id_list") # Now safe to explode
        
        # 4. Final Construction
        .select(
            cc=pl.lit(cc),
            uni=pl.lit(uni),
            course_id=pl.col("source_id_list"),
            articulation=pl.struct(
                conj=pl.col("global_conj"),
                items=pl.col("sending_items").list.eval(
                    pl.struct(
                        conj=pl.element().struct.field("courseConjunction"),
                        items=pl.element().struct.field("items").list.eval(
                            pl.element().struct.field("courseIdentifierParentId")
                        )
                    )
                )
            )
        )
        .group_by(
            pl.col("course_id"),
            pl.col("cc"),
            pl.col("uni")
        ).all()
        .select(
            course_id=pl.col("course_id"),
            cc=pl.col("cc"),
            uni=pl.col("uni"),
            articulation=pl.struct(
                conj=pl.lit("Or"),
                items=pl.col("articulation")
            )
        )
        .collect()
    )


def to_dnf(node):
    """
    Recursively flattens a logic tree of arbitrary depth into a 2D matrix.
    Returns: List[List[int]] (Disjunctive Normal Form)
    """
    # base: no conjunctions
    if not isinstance(node, dict):
        return [[node]] if node is not None else []

    # extract logic & children
    conj = node.get("conj")
    children = node.get("items")
    if not children:
        return []

    # base: And/Or depth=1
    if all(isinstance(child, int) for child in children):
        if conj == "And":
            return [children]  # And(1, 2) -> [[1, 2]]
        else: 
            return [[x] for x in children]  # Or(1, 2) -> [[1], [2]]

    # recurse children to child matrices
    child_matrices = [to_dnf(child) for child in children]

    # DNF algorithm: apply associative property on Or(1, 2, Or(3))
    if conj == "Or":
        merged_matrix = []
        for matrix in child_matrices:
            merged_matrix.extend(matrix)
        return merged_matrix

    # DNF algorithm: apply distributive property (And over Or)
    #    (A OR B) AND (C OR D)
    # => (A AND (C OR D)) OR (B AND (C OR D))
    # => (A AND C) OR (A AND D) OR (B AND C) OR (B AND D)
    elif conj == "And":
        product = itertools.product(*child_matrices)
        
        merged_matrix = []
        for combination in product:
            new_clause = []
            for clause in combination:
                new_clause.extend(clause)
            merged_matrix.append(new_clause)
            
        return merged_matrix
    
    return []


@lru_cache(maxsize=128)
def _resolve_supertype(dtype1: pl.DataType, dtype2: pl.DataType) -> pl.DataType:
    """
    Caches the expensive supertype resolution.
    This bypasses creating dummy Series for repetitive primitive merges
    (e.g., merging Int64 and Float64 thousands of times).
    """
    try:
        # diagonal_relaxed allows Polars to determine the common supertype
        return pl.concat(
            [pl.Series([None], dtype=dtype1), pl.Series([None], dtype=dtype2)],
            how="diagonal_relaxed",
        ).dtype
    except Exception:
        raise TypeError(f"Could not merge incompatible types: {dtype1} and {dtype2}")
    return


def _merge_dtypes_optimized(dtype1: pl.DataType, dtype2: pl.DataType) -> pl.DataType:
    """Optimized recursive merge."""
    # 1. Identity Check (Fastest exit)
    if dtype1 == dtype2:
        return dtype1

    # 2. Null Handling
    if isinstance(dtype1, pl.Null): return dtype2
    if isinstance(dtype2, pl.Null): return dtype1

    # 3. Recursive List Merge
    if isinstance(dtype1, pl.List) and isinstance(dtype2, pl.List):
        return pl.List(_merge_dtypes_optimized(dtype1.inner, dtype2.inner))

    # 4. Recursive Struct Merge
    if isinstance(dtype1, pl.Struct) and isinstance(dtype2, pl.Struct):
        # Convert both to dictionaries once
        f1 = dtype1.to_schema()
        f2 = dtype2.to_schema()
        
        # Start with f1's fields
        merged_fields = f1.copy()
        
        # Only iterate over fields in f2
        for key, type2 in f2.items():
            type1 = merged_fields.get(key)
            if type1 is not None:
                # Recursively merge only if types differ
                if type1 != type2:
                    merged_fields[key] = _merge_dtypes_optimized(type1, type2)
            else:
                # New field from f2
                merged_fields[key] = type2
        
        return pl.Struct(merged_fields)

    # 5. Cached Primitive Resolution
    # We use the cached function for scalar types (Int, Float, String, etc.)
    return _resolve_supertype(dtype1, dtype2)


def merge_schemas(schemas: list[pl.Schema]) -> pl.Schema:
    """
    Optimized schema merging.
    """
    if not schemas:
        return pl.Schema()

    # Convert the first schema to a mutable dictionary immediately
    # casting to dict() is cheaper than repetitive lookups on a Schema object
    current_schema_map = dict(schemas[0])

    for schema in schemas[1:]:
        # Iterate only over the new schema's items
        for field_name, new_dtype in schema.items():
            existing_dtype = current_schema_map.get(field_name)
            
            if existing_dtype is None:
                # Fast path: New field
                current_schema_map[field_name] = new_dtype
            elif existing_dtype != new_dtype:
                # Slow path: Conflict resolution
                current_schema_map[field_name] = _merge_dtypes_optimized(existing_dtype, new_dtype)
    
    return pl.Schema(current_schema_map)

In [None]:
schema_list_prefix = [
    pl.read_json(
        source=fp,
        infer_schema_length=None
    ).schema 
    for fp 
    in PROJECTDIR.glob("data/*/*prefixes.json")
]
schema_prefix = merge_schemas(schema_list_prefix)
del schema_list_prefix

In [343]:
testpath = Path("/home/akash/Main/projects/CACourses/data/129/101to129-prefixes.json")

test = extract_articulations(fp=testpath, schema=schema_prefix)

dnf_matrix_df = test.with_columns(
    pl.col("articulation")
    .map_elements(to_dnf, return_dtype=pl.Object)
)

dnf_matrix_df.sort("course_id", "cc", "uni")

course_id,cc,uni,articulation
i64,i32,i32,object
25929,101,129,[[338761]]
25981,101,129,"[[41543], [336082, 41543]]"
27466,101,129,"[[309095, 287142], [309095, 304417, 287142], [309095]]"
27468,101,129,"[[251680, 287129], [251680]]"
27486,101,129,"[[304417], [287142, 304417], [309095, 304417, 287142]]"
…,…,…,…
359040,101,129,"[[307776], [280607, 307776], [307918, 359112, 307776]]"
359198,101,129,"[[359100], [359100, 358890]]"
359201,101,129,[[281073]]
359225,101,129,"[[358952], [280826, 358952]]"


In [345]:
prefixes_agg = pl.concat([
    extract_articulations(fp, schema_prefix).with_columns(
        pl.col("articulation")
        .map_elements(to_dnf, return_dtype=pl.List(pl.List(pl.Int64)))
    ).lazy()
    for fp
    in PROJECTDIR.glob("data/*/*-prefixes.json")
]).collect().unique()

prefixes_agg

course_id,cc,uni,articulation
i64,i32,i32,list[list[i64]]
259939,114,7,"[[201884, 207755]]"
353759,133,46,"[[205081, 212781], [204750, 212781]]"
377659,153,1,"[[372203], [372214]]"
174287,94,88,[[302608]]
99057,109,24,[[28143]]
…,…,…,…
73014,56,29,[[76319]]
85995,150,88,[[350038]]
353765,200,46,[[377121]]
288885,107,7,[[273944]]


In [330]:
import pickle

with (PROJECTDIR/"etl-pipeline/schema_prefix.pickle").open(mode="wb") as fp:
    pickle.dump(obj=schema_prefix, file=fp)