# In [None]:
import os, shutil, time, random, subprocess
import numpy as np
import pandas as pd
from numba import njit
import math

# ---------------------------
# 0) Paths / setup
# ---------------------------
# Read-only Kaggle inputs: a starting submission and the bbox3 binary.
INPUT_SUB = "/kaggle/input/why-not/submission.csv"
INPUT_BIN = "/kaggle/input/a-bit-better/bbox3"

# Writable working copies of both inputs live under WORKING_DIR.
WORKING_DIR = "/kaggle/working"
SUB_PATH, BIN_PATH = (
    os.path.join(WORKING_DIR, fname) for fname in ("submission.csv", "bbox3")
)

def setup_environment():
    """Stage the input submission and the bbox3 binary into WORKING_DIR.

    The inputs are processed in order (submission first, then binary): each
    one is copied to its working path, or FileNotFoundError(<input path>) is
    raised if it does not exist. Copies made before the failure are kept.
    Finally the working copy of the binary is made executable (mode 0o755).
    """
    os.makedirs(WORKING_DIR, exist_ok=True)

    staged = ((INPUT_SUB, SUB_PATH), (INPUT_BIN, BIN_PATH))
    for src, dst in staged:
        if not os.path.exists(src):
            raise FileNotFoundError(src)
        shutil.copy(src, dst)
        print(f"✅ Copied {src} -> {dst}")

    os.chmod(BIN_PATH, 0o755)
    print("✅ Permissions set for bbox3")

# Stage the inputs as soon as this cell runs; raises FileNotFoundError if
# either Kaggle input path is missing.
setup_environment()


# ---------------------------
# 1) Fast scoring (hot loop)
#    - No shapely
#    - Uses your polygon template + bbox across all trees in a group
# ---------------------------

@njit
def make_polygon_template():
    """Return the unrotated tree outline as two float64 arrays (x, y).

    The 15 vertices start at the tip (0, 0.8), walk down the right side
    through the two tier shoulders to the base (y = 0), go around the trunk
    (bottom at y = -0.2) and climb back up the left side. The vertex order
    matches the shapely polygon built in overlap_audit_if_needed.
    """
    trunk_w = 0.15
    trunk_h = 0.2
    base_w = 0.7
    mid_w = 0.4
    top_w = 0.25
    tip_y = 0.8
    tier1_y = 0.5
    tier2_y = 0.25
    base_y = 0.0
    trunk_bot_y = -trunk_h

    xs = np.array(
        [0, top_w/2, top_w/4, mid_w/2, mid_w/4, base_w/2,
         trunk_w/2, trunk_w/2, -trunk_w/2, -trunk_w/2,
         -base_w/2, -mid_w/4, -mid_w/2, -top_w/4, -top_w/2],
        np.float64,
    )
    ys = np.array(
        [tip_y, tier1_y, tier1_y, tier2_y, tier2_y, base_y, base_y,
         trunk_bot_y, trunk_bot_y, base_y, base_y,
         tier2_y, tier2_y, tier1_y, tier1_y],
        np.float64,
    )
    return xs, ys

# Module-level template vertices shared by every score_group call.
TX, TY = make_polygon_template()

@njit
def score_group(xs, ys, degs, tx, ty):
    """
    Return side**2 / n for one group of n trees.

    Every template vertex (tx[v], ty[v]) is rotated by degs[k] degrees and
    shifted by (xs[k], ys[k]). `side` is the larger of the width and height
    of the axis-aligned box enclosing all transformed vertices.
    """
    count = xs.size
    nverts = tx.size

    # Running extents of the union bounding box.
    lo_x = 1e300
    lo_y = 1e300
    hi_x = -1e300
    hi_y = -1e300

    for k in range(count):
        theta = degs[k] * math.pi / 180.0
        cos_t = math.cos(theta)
        sin_t = math.sin(theta)
        cx = xs[k]
        cy = ys[k]
        for v in range(nverts):
            px = tx[v]
            py = ty[v]
            wx = cos_t * px - sin_t * py + cx
            wy = sin_t * px + cos_t * py + cy
            # Separate ifs (not elif): the first vertex must update both
            # the min and the max of each axis.
            if wx < lo_x:
                lo_x = wx
            if wx > hi_x:
                hi_x = wx
            if wy < lo_y:
                lo_y = wy
            if wy > hi_y:
                hi_y = wy

    width = hi_x - lo_x
    height = hi_y - lo_y
    side = height if height > width else width

    return (side * side) / count

def _strip_s_fast(arr: pd.Series) -> np.ndarray:
    # Expect strings like "s0.123"
    # Using pandas vector ops is faster than Python loops.
    return arr.str.slice(1).astype(np.float64).to_numpy()

def fast_total_score(df: pd.DataFrame) -> float:
    """
    Scores whole submission quickly:
    - derive group id from first 3 chars of id (e.g., '001_...')
    - sum score_group over groups 1..200 (any other group id is ignored)

    Raises ValueError if one of the columns id/x/y/deg is missing; pandas
    also raises ValueError when an id prefix or a coordinate cannot be parsed.
    """
    # Ensure expected columns exist
    for col in ("id", "x", "y", "deg"):
        if col not in df.columns:
            raise ValueError(f"Missing column: {col}")

    # Group key as int (1..200 expected). Grouping by this standalone Series
    # (aligned on df's index) avoids copying the whole frame just to attach a
    # helper column, a cost otherwise paid on every scoring call.
    group_key = df["id"].str.slice(0, 3).astype(np.int16)

    total = 0.0
    # groupby is much faster than repeatedly filtering with startswith;
    # sort=False visits groups in first-appearance order.
    for gi, block in df.groupby(group_key, sort=False):
        # Only score groups 1..200 (ignore anything else if present)
        if gi < 1 or gi > 200:
            continue
        xs = _strip_s_fast(block["x"])
        ys = _strip_s_fast(block["y"])
        degs = _strip_s_fast(block["deg"])
        total += float(score_group(xs, ys, degs, TX, TY))
    return total


# ---------------------------
# 2) Optional overlap audit (slow) - only run when improving
#    Keep your shapely method, but don't do it every iteration.
# ---------------------------
def overlap_audit_if_needed(df: pd.DataFrame, do_audit: bool) -> bool:
    """
    Returns True if overlaps are detected, False otherwise.
    If do_audit==False, always returns False (skips).

    When auditing, every group 001..200 is rebuilt as shapely polygons from
    the x/y/deg strings (first character, the "s" prefix, dropped) and checked
    pairwise: two trees overlap when their polygons intersect but do not
    merely touch. The ids of offending groups (at most the first 20) are
    printed.
    """
    if not do_audit:
        return False

    # Imported lazily so shapely is only needed when an audit is requested.
    from decimal import Decimal, getcontext
    from shapely import affinity
    from shapely.geometry import Polygon
    from shapely.strtree import STRtree

    # NOTE: getcontext() is the current thread's decimal context, so this
    # precision change persists after the function returns.
    getcontext().prec = 25
    # Geometry is built in coordinates scaled by 1e18 (then converted to float
    # for shapely); presumably mirrors the official metric's scaling — TODO confirm.
    scale_factor = Decimal("1e18")

    class ChristmasTree:
        """Tree polygon rotated by `angle` degrees about (0, 0), then shifted
        to (center_x, center_y); all coordinates multiplied by scale_factor."""
        def __init__(self, center_x="0", center_y="0", angle="0"):
            self.center_x = Decimal(center_x)
            self.center_y = Decimal(center_y)
            self.angle = Decimal(angle)

            trunk_w = Decimal("0.15")
            trunk_h = Decimal("0.2")
            base_w  = Decimal("0.7")
            mid_w   = Decimal("0.4")
            top_w   = Decimal("0.25")
            tip_y   = Decimal("0.8")
            t1_y    = Decimal("0.5")
            t2_y    = Decimal("0.25")
            base_y  = Decimal("0.0")
            tbot_y  = -trunk_h

            # 15 vertices: tip, down the right side, around the trunk, back up
            # the left side (same order as make_polygon_template).
            initial_polygon = Polygon([
                (Decimal("0.0")*scale_factor, tip_y*scale_factor),
                (top_w/Decimal("2")*scale_factor, t1_y*scale_factor),
                (top_w/Decimal("4")*scale_factor, t1_y*scale_factor),
                (mid_w/Decimal("2")*scale_factor, t2_y*scale_factor),
                (mid_w/Decimal("4")*scale_factor, t2_y*scale_factor),
                (base_w/Decimal("2")*scale_factor, base_y*scale_factor),
                (trunk_w/Decimal("2")*scale_factor, base_y*scale_factor),
                (trunk_w/Decimal("2")*scale_factor, tbot_y*scale_factor),
                (-(trunk_w/Decimal("2"))*scale_factor, tbot_y*scale_factor),
                (-(trunk_w/Decimal("2"))*scale_factor, base_y*scale_factor),
                (-(base_w/Decimal("2"))*scale_factor, base_y*scale_factor),
                (-(mid_w/Decimal("4"))*scale_factor, t2_y*scale_factor),
                (-(mid_w/Decimal("2"))*scale_factor, t2_y*scale_factor),
                (-(top_w/Decimal("4"))*scale_factor, t1_y*scale_factor),
                (-(top_w/Decimal("2"))*scale_factor, t1_y*scale_factor),
            ])

            # Rotate about the tree's local origin first, then translate.
            rotated = affinity.rotate(initial_polygon, float(self.angle), origin=(0, 0))
            self.polygon = affinity.translate(
                rotated,
                xoff=float(self.center_x * scale_factor),
                yoff=float(self.center_y * scale_factor)
            )

    def load_group_trees(group_id: int, dff: pd.DataFrame):
        """Build a ChristmasTree for every row whose id starts with
        f"{group_id:03d}_"; x/y/deg lose their first character (the "s")."""
        key = f"{group_id:03d}_"
        block = dff[dff["id"].str.startswith(key)]
        trees = []
        for _, row in block.iterrows():
            trees.append(ChristmasTree(row["x"][1:], row["y"][1:], row["deg"][1:]))
        return trees

    def has_overlap(polys) -> bool:
        """True if any two polygons intersect with more than boundary contact."""
        if len(polys) <= 1:
            return False
        # The spatial index limits candidates to bounding-box hits.
        # Assumes shapely >= 2, where query() returns integer indices into polys.
        idx = STRtree(polys)
        for i, p in enumerate(polys):
            hits = idx.query(p)
            for j in hits:
                if j == i:
                    continue
                # Note: each unordered pair is tested twice, as (i, j) and (j, i).
                if p.intersects(polys[j]) and not p.touches(polys[j]):
                    return True
        return False

    # Scan every group and collect the ids of those containing an overlap.
    bad = []
    for gi in range(1, 201):
        trees = load_group_trees(gi, df)
        polys = [t.polygon for t in trees]
        if has_overlap(polys):
            bad.append(gi)

    if bad:
        # Only the first 20 offending group ids are printed.
        print("❌ Overlap detected in groups:", bad[:20], ("..." if len(bad) > 20 else ""))
        return True

    print("✅ Overlap audit passed")
    return False


# ---------------------------
# 3) Search strategy around bbox3
#    - Keep best
#    - Explore then exploit
# ---------------------------

# Snapshot of the best-scoring submission seen so far.
BEST_BACKUP = os.path.join(WORKING_DIR, "best_submission.csv")

def read_submission() -> pd.DataFrame:
    """Load the working submission.csv into a DataFrame."""
    frame = pd.read_csv(SUB_PATH)
    return frame

def _copy_over(src, dst) -> None:
    """Copy *src* onto *dst* with shutil.copy (contents plus permission bits)."""
    shutil.copy(src, dst)

def save_best():
    """Record the current working submission as the best-known snapshot."""
    _copy_over(SUB_PATH, BEST_BACKUP)

def restore_best():
    """Overwrite the working submission with the best-known snapshot."""
    _copy_over(BEST_BACKUP, SUB_PATH)

def run_bbox3(n: int, r: int):
    """Run the bbox3 binary once as ``bbox3 -n <n> -r <r>`` inside WORKING_DIR.

    stdout and stderr are discarded. A non-zero exit status raises
    subprocess.CalledProcessError (check=True).
    """
    # cwd is pinned because bbox3 is assumed to rewrite submission.csv in its
    # current working directory.
    argv = [BIN_PATH, "-n", str(n), "-r", str(r)]
    subprocess.run(
        argv,
        check=True,
        cwd=WORKING_DIR,
        stderr=subprocess.DEVNULL,
        stdout=subprocess.DEVNULL,
    )

def clamp(v, lo, hi):
    """Return v limited to the closed range [lo, hi]."""
    if v < lo:
        return lo
    if v > hi:
        return hi
    return v

def propose_candidates(best_n, best_r, phase, k=40):
    """Generate (n, r) pairs to try with bbox3.

    phase 0 (explore): k uniform random pairs, n in [50, 500], r in [10, 80].
    any other phase (refine): k Gaussian perturbations of (best_n, best_r)
    (sigma 40 for n, 8 for r, truncated to int, clamped to the same ranges),
    followed by the four single-axis neighbours r+1, r-1, n+25, n-25.

    Duplicates are removed; the first occurrence keeps its position.
    """
    if phase == 0:
        # Tuple items evaluate left to right: n is drawn before r.
        raw = [(random.randint(50, 500), random.randint(10, 80)) for _ in range(k)]
    else:
        raw = []
        for _ in range(k):
            step_n = int(random.gauss(0, 40))
            step_r = int(random.gauss(0, 8))
            raw.append((clamp(best_n + step_n, 50, 500), clamp(best_r + step_r, 10, 80)))
        raw.extend([
            (best_n, clamp(best_r + 1, 10, 80)),
            (best_n, clamp(best_r - 1, 10, 80)),
            (clamp(best_n + 25, 50, 500), best_r),
            (clamp(best_n - 25, 50, 500), best_r),
        ])
    # dict preserves insertion order, so this is an order-preserving dedupe.
    return list(dict.fromkeys(raw))

def optimize(time_budget_sec=25*60, audit_on_improve=False, max_iters=None):
    """
    Repeatedly run bbox3 with proposed (n, r) settings, keeping the best result.

    Each attempt runs bbox3 on submission.csv (which always holds the
    best-known solution), scores the output with fast_total_score, and either
    accepts it (strictly lower score and, if requested, a clean overlap audit)
    or rolls submission.csv back to the saved best. Random exploration runs in
    batches until at least 120 attempts were made; later batches refine around
    the best (n, r).

    time_budget_sec: stop after this many seconds (checked before each attempt,
        so a running bbox3 call can overshoot it)
    audit_on_improve: run slow overlap check only when we get a new best
    max_iters: optional hard cap; exactly max_iters attempts run if time allows

    Returns (best_score, best_n, best_r); best_n/best_r stay at the initial
    anchor (465, 50) when no attempt improved the score.
    """
    random.seed()

    df0 = read_submission()
    best_score = fast_total_score(df0)
    # Anchor for the refinement phase. It is replaced only on improvement, so
    # refinement starts from here if exploration found nothing better.
    best_n, best_r = 465, 50

    save_best()
    print(f"Initial score: {best_score:.12f}")

    # monotonic clock: elapsed time is immune to wall-clock adjustments.
    start = time.monotonic()

    def budget_exhausted(attempts):
        """True once the attempt cap or the time budget has been reached."""
        if max_iters is not None and attempts >= max_iters:
            return True
        return time.monotonic() - start > time_budget_sec

    it = 0
    phase = 0

    while not budget_exhausted(it):
        # Switch to refinement after enough exploration. `>=` (not `==`):
        # deduplicated or interrupted batches can step past 120 exactly.
        if phase == 0 and it >= 120:
            phase = 1

        candidates = propose_candidates(best_n, best_r, phase, k=30 if phase == 0 else 40)

        for n, r in candidates:
            # Check before counting, so max_iters=N yields N runs (not N-1).
            if budget_exhausted(it):
                break
            it += 1

            # Run bbox3
            try:
                run_bbox3(n, r)
            except subprocess.CalledProcessError:
                # If bbox3 fails, rollback and continue
                restore_best()
                continue

            # Score quickly. Malformed output (unparseable CSV, missing column,
            # non-numeric values) surfaces as ValueError: reject and roll back.
            try:
                df = read_submission()
                sc = fast_total_score(df)
            except ValueError:
                restore_best()
                continue

            if sc < best_score:
                # Optional correctness audit (slow)
                if overlap_audit_if_needed(df, audit_on_improve):
                    # invalid -> rollback
                    restore_best()
                    continue

                # Accept improvement
                best_score = sc
                best_n, best_r = n, r
                save_best()
                print(f"✅ New best {best_score:.12f}   (n={best_n}, r={best_r}, iter={it})")
            else:
                # Reject and rollback so we always work from the best-known state
                restore_best()

    # Make sure final submission.csv is best
    restore_best()
    print(f"Done. Best score: {best_score:.12f}   best (n={best_n}, r={best_r})")
    return best_score, best_n, best_r

# Run optimization.
# A larger time_budget_sec allows more bbox3 attempts, which usually helps the score.
best_score, bn, br = optimize(
    time_budget_sec=60 * 60,    # one hour; adjust to your runtime budget
    audit_on_improve=True,      # shapely overlap audit on every new best (slow, safer)
    max_iters=None,             # no attempt cap; stop on the time budget only
)