# Experiment 004: Numba-Accelerated SA with Lattice Construction

This notebook implements the approach from the jiweiliu kernel:
1. Numba-accelerated geometry operations
2. SA optimization of lattice parameters
3. Backward propagation cascade

In [1]:
import math
import numpy as np
import pandas as pd
from numba import njit
from numba.typed import List as NumbaList
import time
import warnings
warnings.filterwarnings('ignore')

# Tree shape constants, in tree-local coordinates (before rotation/translation).
# The *_W values are full widths: get_tree_vertices places outline points at
# +/- half (and, for the tier notches, +/- a quarter) of each width.
TRUNK_W = 0.15  # trunk width
TRUNK_H = 0.2  # trunk height; the trunk extends below BASE_Y by this amount
BASE_W = 0.7  # outline width at BASE_Y
MID_W = 0.4  # outline width at TIER_2_Y
TOP_W = 0.25  # outline width at TIER_1_Y
TIP_Y = 0.8  # y of the apex vertex (x = 0)
TIER_1_Y = 0.5  # y of the upper tier edge
TIER_2_Y = 0.25  # y of the middle tier edge
BASE_Y = 0.0  # y of the lowest tier edge / top of the trunk
TRUNK_BOTTOM_Y = -TRUNK_H  # y of the trunk's bottom edge

# Squared center-to-center distance above which polygons_overlap returns False
# without running the exact test. The outline vertex farthest from the tree
# origin is the tip at distance TIP_Y = 0.8, so two trees whose centers are more
# than 1.6 apart cannot touch; 1.8 keeps a safety margin on top of that.
MAX_OVERLAP_DIST_SQ = 1.8 * 1.8

In [2]:
# Numba-accelerated geometry functions
@njit(cache=True)
def rotate_point(x, y, cos_a, sin_a):
    """Rotate the point (x, y) about the origin.

    The rotation is given by its precomputed cosine and sine so callers can
    reuse them across many points. Returns the rotated (x, y) pair.
    """
    rotated_x = x * cos_a - y * sin_a
    rotated_y = x * sin_a + y * cos_a
    return rotated_x, rotated_y

@njit(cache=True)
def get_tree_vertices(cx, cy, angle_deg):
    """Return the tree outline as a (15, 2) float64 array of world coordinates.

    The outline is defined in tree-local coordinates from the module-level
    shape constants, rotated about the local origin by ``angle_deg`` degrees
    and then translated by ``(cx, cy)``. Vertex 0 is the tip; vertices 1-7
    run down the right-hand side to the trunk bottom and vertices 8-14 run
    back up the left-hand side.
    """
    theta = angle_deg * math.pi / 180.0
    cos_a = math.cos(theta)
    sin_a = math.sin(theta)

    # Right-hand half of the outline, from just below the tip down to the
    # trunk's bottom corner. The left-hand half is its mirror image in x.
    right_half = np.array([
        [TOP_W / 2.0, TIER_1_Y],
        [TOP_W / 4.0, TIER_1_Y],
        [MID_W / 2.0, TIER_2_Y],
        [MID_W / 4.0, TIER_2_Y],
        [BASE_W / 2.0, BASE_Y],
        [TRUNK_W / 2.0, BASE_Y],
        [TRUNK_W / 2.0, TRUNK_BOTTOM_Y],
    ], dtype=np.float64)

    outline = np.empty((15, 2), dtype=np.float64)
    outline[0, 0] = 0.0
    outline[0, 1] = TIP_Y
    for k in range(7):
        # Right-side vertex k + 1 and its left-side mirror at index 14 - k.
        outline[k + 1, 0] = right_half[k, 0]
        outline[k + 1, 1] = right_half[k, 1]
        outline[14 - k, 0] = -right_half[k, 0]
        outline[14 - k, 1] = right_half[k, 1]

    # Rotate every local vertex in place, then shift it to the tree's center.
    for v in range(15):
        turned_x, turned_y = rotate_point(outline[v, 0], outline[v, 1], cos_a, sin_a)
        outline[v, 0] = turned_x + cx
        outline[v, 1] = turned_y + cy
    return outline

@njit(cache=True)
def polygon_bounds(vertices):
    """Return the axis-aligned bounding box of a vertex array.

    ``vertices`` is indexed as ``vertices[i, 0]`` (x) and ``vertices[i, 1]``
    (y). The result is the tuple ``(min_x, min_y, max_x, max_y)``.
    """
    # Seed all four extremes from the first vertex, then widen as needed.
    lo_x = hi_x = vertices[0, 0]
    lo_y = hi_y = vertices[0, 1]
    for k in range(1, vertices.shape[0]):
        vx = vertices[k, 0]
        vy = vertices[k, 1]
        # A value cannot be both below the minimum and above the maximum,
        # so each axis needs at most one update per vertex.
        if vx < lo_x:
            lo_x = vx
        elif vx > hi_x:
            hi_x = vx
        if vy < lo_y:
            lo_y = vy
        elif vy > hi_y:
            hi_y = vy
    return lo_x, lo_y, hi_x, hi_y

@njit(cache=True)
def point_in_polygon(px, py, vertices):
    """Ray-casting test: return True if (px, py) lies inside the polygon.

    Each polygon edge that straddles the horizontal line ``y = py`` and
    crosses it to the right of ``px`` toggles the inside/outside state.
    """
    count = vertices.shape[0]
    is_inside = False
    prev = count - 1  # start with the closing edge (last vertex -> first)
    for cur in range(count):
        x_cur = vertices[cur, 0]
        y_cur = vertices[cur, 1]
        x_prev = vertices[prev, 0]
        y_prev = vertices[prev, 1]
        # Only edges whose endpoints are on opposite sides of y = py can be
        # crossed; this also guarantees y_prev != y_cur for the division.
        if (y_cur > py) != (y_prev > py):
            x_at_py = (x_prev - x_cur) * (py - y_cur) / (y_prev - y_cur) + x_cur
            if px < x_at_py:
                is_inside = not is_inside
        prev = cur
    return is_inside

@njit(cache=True)
def segments_intersect(p1x, p1y, p2x, p2y, p3x, p3y, p4x, p4y):
    """Return True unless one segment lies strictly on one side of the other.

    Segment A runs p1 -> p2 and segment B runs p3 -> p4. The test uses the
    signs of 2-D cross products; a zero cross product (touching or collinear
    endpoints) is not treated as separation, so such cases return True.
    """
    # Direction of B, and on which side of B's line each endpoint of A falls.
    b_dx = p4x - p3x
    b_dy = p4y - p3y
    side_p1 = b_dx * (p1y - p3y) - b_dy * (p1x - p3x)
    side_p2 = b_dx * (p2y - p3y) - b_dy * (p2x - p3x)
    if side_p1 * side_p2 > 0:
        # Both endpoints of A are strictly on the same side of B.
        return False

    # Direction of A, and on which side of A's line each endpoint of B falls.
    a_dx = p2x - p1x
    a_dy = p2y - p1y
    side_p3 = a_dx * (p3y - p1y) - a_dy * (p3x - p1x)
    side_p4 = a_dx * (p4y - p1y) - a_dy * (p4x - p1x)
    return not (side_p3 * side_p4 > 0)

In [3]:
@njit(cache=True)
def polygons_overlap(verts1, verts2, cx1, cy1, cx2, cy2):
    """Return True if the two polygons overlap or touch.

    ``verts1`` / ``verts2`` are vertex arrays and ``(cx1, cy1)`` /
    ``(cx2, cy2)`` the corresponding centers. Two cheap rejections (center
    distance, bounding boxes) run first; the exact test then looks for a
    vertex of one polygon inside the other or for any pair of crossing edges.
    """
    # Rejection 1: centers farther apart than the configured threshold.
    sep_x = cx2 - cx1
    sep_y = cy2 - cy1
    if sep_x * sep_x + sep_y * sep_y > MAX_OVERLAP_DIST_SQ:
        return False

    # Rejection 2: disjoint axis-aligned bounding boxes.
    a_min_x, a_min_y, a_max_x, a_max_y = polygon_bounds(verts1)
    b_min_x, b_min_y, b_max_x, b_max_y = polygon_bounds(verts2)
    if a_max_x < b_min_x or b_max_x < a_min_x:
        return False
    if a_max_y < b_min_y or b_max_y < a_min_y:
        return False

    count_a = verts1.shape[0]
    count_b = verts2.shape[0]

    # Exact test, part 1: any vertex of one polygon inside the other.
    for a in range(count_a):
        if point_in_polygon(verts1[a, 0], verts1[a, 1], verts2):
            return True
    for b in range(count_b):
        if point_in_polygon(verts2[b, 0], verts2[b, 1], verts1):
            return True

    # Exact test, part 2: any edge of polygon 1 meeting any edge of polygon 2.
    # Edges are walked as (previous vertex -> current vertex), starting with
    # the closing edge, which covers the same edge set as (i -> i + 1).
    a_prev = count_a - 1
    for a_cur in range(count_a):
        b_prev = count_b - 1
        for b_cur in range(count_b):
            if segments_intersect(
                verts1[a_prev, 0], verts1[a_prev, 1],
                verts1[a_cur, 0], verts1[a_cur, 1],
                verts2[b_prev, 0], verts2[b_prev, 1],
                verts2[b_cur, 0], verts2[b_cur, 1],
            ):
                return True
            b_prev = b_cur
        a_prev = a_cur
    return False

@njit(cache=True)
def has_any_overlap(all_vertices, centers_x, centers_y):
    """Return True if any two polygons in ``all_vertices`` overlap.

    ``centers_x[i]`` / ``centers_y[i]`` is the center of ``all_vertices[i]``.
    Every unordered pair is tested with ``polygons_overlap`` and the search
    stops at the first hit.
    """
    total = len(all_vertices)
    for first in range(total):
        first_verts = all_vertices[first]
        first_cx = centers_x[first]
        first_cy = centers_y[first]
        for second in range(first + 1, total):
            hit = polygons_overlap(
                first_verts, all_vertices[second],
                first_cx, first_cy,
                centers_x[second], centers_y[second],
            )
            if hit:
                return True
    return False

@njit(cache=True)
def get_side_length(all_vertices):
    """Return the side of the smallest axis-aligned square covering all polygons.

    The bounding boxes of the individual polygons are merged into one box and
    the larger of its width and height is returned.
    """
    lo_x = math.inf
    lo_y = math.inf
    hi_x = -math.inf
    hi_y = -math.inf
    for k in range(len(all_vertices)):
        box_lo_x, box_lo_y, box_hi_x, box_hi_y = polygon_bounds(all_vertices[k])
        if box_lo_x < lo_x:
            lo_x = box_lo_x
        if box_lo_y < lo_y:
            lo_y = box_lo_y
        if box_hi_x > hi_x:
            hi_x = box_hi_x
        if box_hi_y > hi_y:
            hi_y = box_hi_y
    width = hi_x - lo_x
    height = hi_y - lo_y
    return max(width, height)

@njit(cache=True)
def calculate_score_numba(all_vertices):
    """Return the score of one configuration: bounding-square area per polygon."""
    edge = get_side_length(all_vertices)
    square_area = edge * edge
    return square_area / len(all_vertices)

In [4]:
# Load baseline and calculate score
def load_submission_data(filepath):
    """Load a submission CSV into flat x / y / angle arrays.

    The CSV must have the columns ``id``, ``x``, ``y`` and ``deg``. Ids have
    the form ``NNN_T`` (zero-padded group number, tree index); numeric values
    may carry an ``s`` marker, which is stripped before conversion.

    Rows are emitted group by group for N = 1..200 and, within a group, in
    increasing *numeric* tree index. (A plain string sort of the ids would
    put e.g. ``011_10`` before ``011_2`` and scramble the tree order in every
    group with more than ten trees.) Rows whose suffix is not a number are
    kept and placed after the numbered ones, in file order.

    Args:
        filepath: Path of the CSV file to read.

    Returns:
        Tuple ``(xs, ys, degs)`` of 1-D float NumPy arrays of equal length.
    """
    df = pd.read_csv(filepath)
    # Work on a string view of the ids so a non-string or NaN id cannot break
    # the prefix match below.
    ids = df["id"].astype(str)

    def _floats(frame, column):
        # Same parsing as float(str(value).replace('s', '')), done column-wise.
        cleaned = frame[column].astype(str).str.replace("s", "", regex=False)
        return cleaned.astype(float).tolist()

    all_xs, all_ys, all_degs = [], [], []
    for n in range(1, 201):
        prefix = f"{n:03d}_"
        in_group = ids.str.startswith(prefix)
        group = df[in_group]

        # Order by the numeric tree index; unparsable suffixes become NaN,
        # which a stable argsort leaves at the end in their original order.
        tree_index = pd.to_numeric(ids[in_group].str[len(prefix):], errors="coerce")
        order = np.argsort(tree_index.to_numpy(dtype=float), kind="stable")
        group = group.iloc[order]

        all_xs.extend(_floats(group, "x"))
        all_ys.extend(_floats(group, "y"))
        all_degs.extend(_floats(group, "deg"))
    return np.array(all_xs), np.array(all_ys), np.array(all_degs)

def calculate_total_score(all_xs, all_ys, all_degs):
    """Sum the per-configuration scores for N = 1..200.

    The flat input arrays hold the configurations back to back: the N-tree
    configuration occupies N consecutive entries, directly after the entries
    of all smaller configurations.
    """
    running_total = 0.0
    offset = 0
    for group_size in range(1, 201):
        polygons = NumbaList()
        for k in range(offset, offset + group_size):
            polygons.append(get_tree_vertices(all_xs[k], all_ys[k], all_degs[k]))
        running_total += calculate_score_numba(polygons)
        offset += group_size
    return running_total

# Load the baseline submission into flat arrays (one entry per tree, all
# configurations back to back).
print("Loading baseline...")
baseline_xs, baseline_ys, baseline_degs = load_submission_data('/home/code/santa-2025-csv/santa-2025.csv')
print(f"Loaded {len(baseline_xs)} tree positions")

# Score the baseline. The reported time includes any Numba JIT compilation
# triggered by this first call into the @njit functions.
print("\nCalculating baseline score (this will compile Numba functions)...")
start = time.time()
baseline_score = calculate_total_score(baseline_xs, baseline_ys, baseline_degs)
print(f"Baseline score: {baseline_score:.6f} (took {time.time()-start:.1f}s)")

Loading baseline...


Loaded 20100 tree positions

Calculating baseline score (this will compile Numba functions)...


Baseline score: 70.676102 (took 1.1s)


In [None]:
# Backward propagation cascade using Numba
@njit(cache=True)
def _tree_boxes(xs, ys, degs, start, count):
    """Return a (count, 4) array of (min_x, min_y, max_x, max_y) per tree."""
    boxes = np.empty((count, 4), dtype=np.float64)
    for i in range(count):
        verts = get_tree_vertices(xs[start + i], ys[start + i], degs[start + i])
        lo_x, lo_y, hi_x, hi_y = polygon_bounds(verts)
        boxes[i, 0] = lo_x
        boxes[i, 1] = lo_y
        boxes[i, 2] = hi_x
        boxes[i, 3] = hi_y
    return boxes


@njit(cache=True)
def _enclosing_side(boxes, skip):
    """Return the bounding-square side of all boxes except row ``skip`` (-1 = none)."""
    lo_x = math.inf
    lo_y = math.inf
    hi_x = -math.inf
    hi_y = -math.inf
    for i in range(boxes.shape[0]):
        if i == skip:
            continue
        if boxes[i, 0] < lo_x:
            lo_x = boxes[i, 0]
        if boxes[i, 1] < lo_y:
            lo_y = boxes[i, 1]
        if boxes[i, 2] > hi_x:
            hi_x = boxes[i, 2]
        if boxes[i, 3] > hi_y:
            hi_y = boxes[i, 3]
    return max(hi_x - lo_x, hi_y - lo_y)


@njit(cache=True)
def deletion_cascade_numba(all_xs, all_ys, all_degs):
    """Apply tree deletion cascade from N=200 down to N=2.

    For each N (largest first), every way of removing one tree from the
    current N-tree configuration is evaluated. If the best resulting
    (N-1)-tree layout has a smaller bounding-square side than the stored
    (N-1)-tree configuration (by more than 1e-10), it replaces that
    configuration, so an improvement can propagate further down in later
    iterations.

    Args:
        all_xs, all_ys, all_degs: Flat arrays holding the configurations for
            N = 1..200 back to back (N entries each, 20100 in total).

    Returns:
        ``(new_xs, new_ys, new_degs, side_lengths, improvements)``: updated
        copies of the inputs, an array of length 201 whose entry N is the
        bounding-square side of the N-tree configuration (entry 0 unused),
        and the number of configurations that were replaced.

    Raises:
        ValueError: If any input holds fewer than 20100 entries; indexing
            past the end is not bounds-checked in nopython mode.
    """
    total_trees = 200 * 201 // 2
    if (all_xs.shape[0] < total_trees or all_ys.shape[0] < total_trees
            or all_degs.shape[0] < total_trees):
        raise ValueError("inputs must hold the 20100 trees of configurations 1..200")

    # group_start[n] is the flat index of the first tree of configuration n,
    # i.e. the number of trees in configurations 1..n-1.
    group_start = np.zeros(201, dtype=np.int64)
    for n in range(2, 201):
        group_start[n] = group_start[n - 1] + (n - 1)

    # Work on copies so the caller's arrays stay untouched.
    new_xs = all_xs.copy()
    new_ys = all_ys.copy()
    new_degs = all_degs.copy()

    # Bounding-square side of every configuration as loaded.
    side_lengths = np.zeros(201, dtype=np.float64)
    for n in range(1, 201):
        boxes = _tree_boxes(new_xs, new_ys, new_degs, group_start[n], n)
        side_lengths[n] = _enclosing_side(boxes, -1)

    improvements = 0

    # Cascade from n=200 down to n=2
    for n in range(200, 1, -1):
        start_n = group_start[n]
        start_prev = group_start[n - 1]

        # Removing a tree never moves the others, so the per-tree boxes are
        # computed once per configuration and reused for every candidate.
        boxes = _tree_boxes(new_xs, new_ys, new_degs, start_n, n)

        best_side = side_lengths[n - 1]
        best_delete_idx = -1
        for del_idx in range(n):
            candidate_side = _enclosing_side(boxes, del_idx)
            # The tolerance keeps float noise from counting as an improvement.
            if candidate_side < best_side - 1e-10:
                best_side = candidate_side
                best_delete_idx = del_idx

        if best_delete_idx >= 0:
            # Overwrite configuration n-1 with configuration n minus the
            # chosen tree. The two index ranges are disjoint, so copying in
            # place is safe.
            out_idx = start_prev
            for i in range(n):
                if i != best_delete_idx:
                    in_idx = start_n + i
                    new_xs[out_idx] = new_xs[in_idx]
                    new_ys[out_idx] = new_ys[in_idx]
                    new_degs[out_idx] = new_degs[in_idx]
                    out_idx += 1
            side_lengths[n - 1] = best_side
            improvements += 1

    return new_xs, new_ys, new_degs, side_lengths, improvements

# Run the cascade on the baseline arrays; the inputs are copied inside the
# function, so baseline_xs / baseline_ys / baseline_degs stay unchanged.
print("\nRunning backward propagation cascade...")
start = time.time()
new_xs, new_ys, new_degs, side_lengths, improvements = deletion_cascade_numba(baseline_xs, baseline_ys, baseline_degs)
print(f"Completed in {time.time()-start:.1f}s")
print(f"Improvements found: {improvements}")

# Re-score the cascaded solution and compare it with the baseline
# (a positive "Improvement" means the new total score is lower).
new_score = calculate_total_score(new_xs, new_ys, new_degs)
print(f"\nBaseline score: {baseline_score:.6f}")
print(f"After cascade: {new_score:.6f}")
print(f"Improvement: {baseline_score - new_score:.6f}")

In [None]:
# Save submission
def save_submission(filepath, all_xs, all_ys, all_degs):
    """Write the flat solution arrays to a submission CSV and return the frame.

    One row is written per tree for N = 1..200, with id ``NNN_T`` (zero-padded
    configuration size, tree index within it) and the x / y / deg values
    formatted as strings prefixed with ``s``. The CSV has no index column.
    """
    # Ids in output order; position i in this list corresponds to entry i of
    # the flat input arrays.
    ids = [f"{n:03d}_{t}" for n in range(1, 201) for t in range(n)]
    positions = range(len(ids))

    frame = pd.DataFrame({
        "id": ids,
        "x": [f"s{all_xs[i]}" for i in positions],
        "y": [f"s{all_ys[i]}" for i in positions],
        "deg": [f"s{all_degs[i]}" for i in positions],
    })
    frame.to_csv(filepath, index=False)
    return frame

# Keep whichever solution has the lower total score; on a tie the baseline
# is kept.
if new_score < baseline_score:
    final_xs, final_ys, final_degs = new_xs, new_ys, new_degs
    final_score = new_score
else:
    final_xs, final_ys, final_degs = baseline_xs, baseline_ys, baseline_degs
    final_score = baseline_score

# Report the result against the hard-coded target score; a positive gap
# means the final score is still above the target.
print(f"\nFinal score: {final_score:.6f}")
print(f"Target: 68.922808")
print(f"Gap: {final_score - 68.922808:.6f}")

# Write the chosen solution to both the submission location and this
# experiment's folder.
save_submission('/home/submission/submission.csv', final_xs, final_ys, final_degs)
save_submission('/home/code/experiments/004_numba_sa/submission.csv', final_xs, final_ys, final_degs)
print("\nSubmission saved!")

In [None]:
# Save metrics
import json
# Summary of this run, written as JSON next to the experiment's submission.
metrics = {
    'baseline_score': baseline_score,
    'final_score': final_score,
    # Positive when the final score is lower than the baseline.
    'improvement': baseline_score - final_score,
    # int() turns the value returned by the cascade into a plain Python int
    # for json.dump.
    'cascade_improvements': int(improvements),
    'target': 68.922808
}
with open('/home/code/experiments/004_numba_sa/metrics.json', 'w') as f:
    json.dump(metrics, f, indent=2)
print("Metrics saved")