# [Santa25] Improved simulated annealing with translations

In [the previous notebook](https://www.kaggle.com/code/egortrushin/santa25-simulated-annealing-with-translations), I introduced a simulated annealing implementation for periodic tree structures.

This notebook presents a refined version. The three main changes are:

- The translations are now optimized by simulated annealing moves, instead of being set by finding the minimal translations between unit cells.
- A new simulated annealing move rotates all trees in the unit cell by the same angle.
- An option is added to translate only one tree in the last translation along a given direction. This extends the range of puzzles that can be solved with this approach.

In this notebook, I did not try to obtain the best possible scores. With the provided code, one can easily improve the scores of the solutions given here, and also find solutions for some other $N$ that are better than the current public ones.

In [None]:
import datetime
import json
import copy
import random
import math
import time
import sys
import yaml
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from decimal import Decimal, getcontext
from shapely import affinity
from shapely.geometry import Polygon
from shapely.ops import unary_union
from matplotlib.patches import Rectangle

# Precision (significant digits) of the global Decimal context used by all
# Decimal arithmetic in this notebook.
getcontext().prec = 25
# Tree polygons are built in coordinates multiplied by this factor (see
# ChristmasTree) and divided back out when measuring or plotting.
scale_factor = Decimal("1e15")

In [None]:
class ChristmasTree:
    """A single, rotatable Christmas tree of fixed size.

    Attributes:
        center_x, center_y: ``Decimal`` position of the tree's local origin,
            i.e. the point (0, 0) of the outline, the middle of the bottom
            tier's base line (the trunk hangs below it, the tip is above).
        angle: ``Decimal`` rotation in degrees, applied with
            ``shapely.affinity.rotate`` about the local origin.
        polygon: shapely ``Polygon`` of the placed outline; every coordinate
            is multiplied by the module-level ``scale_factor``.
    """

    # Unrotated outline at the origin, shared by every instance. It depends
    # only on constants, so it is built once (on first use) instead of on
    # every construction -- the annealer creates many trees per step.
    _template = None

    @classmethod
    def _outline(cls):
        """Return the cached, scaled outline of an unrotated tree at (0, 0)."""
        if cls._template is None:
            trunk_w = Decimal("0.15")
            trunk_h = Decimal("0.2")
            base_w = Decimal("0.7")
            mid_w = Decimal("0.4")
            top_w = Decimal("0.25")
            tip_y = Decimal("0.8")
            tier_1_y = Decimal("0.5")
            tier_2_y = Decimal("0.25")
            base_y = Decimal("0.0")
            trunk_bottom_y = -trunk_h

            cls._template = Polygon(
                [
                    # Start at Tip
                    (Decimal("0.0") * scale_factor, tip_y * scale_factor),
                    # Right side - Top Tier
                    (top_w / Decimal("2") * scale_factor, tier_1_y * scale_factor),
                    (top_w / Decimal("4") * scale_factor, tier_1_y * scale_factor),
                    # Right side - Middle Tier
                    (mid_w / Decimal("2") * scale_factor, tier_2_y * scale_factor),
                    (mid_w / Decimal("4") * scale_factor, tier_2_y * scale_factor),
                    # Right side - Bottom Tier
                    (base_w / Decimal("2") * scale_factor, base_y * scale_factor),
                    # Right Trunk
                    (trunk_w / Decimal("2") * scale_factor, base_y * scale_factor),
                    (trunk_w / Decimal("2") * scale_factor, trunk_bottom_y * scale_factor),
                    # Left Trunk
                    (-(trunk_w / Decimal("2")) * scale_factor, trunk_bottom_y * scale_factor),
                    (-(trunk_w / Decimal("2")) * scale_factor, base_y * scale_factor),
                    # Left side - Bottom Tier
                    (-(base_w / Decimal("2")) * scale_factor, base_y * scale_factor),
                    # Left side - Middle Tier
                    (-(mid_w / Decimal("4")) * scale_factor, tier_2_y * scale_factor),
                    (-(mid_w / Decimal("2")) * scale_factor, tier_2_y * scale_factor),
                    # Left side - Top Tier
                    (-(top_w / Decimal("4")) * scale_factor, tier_1_y * scale_factor),
                    (-(top_w / Decimal("2")) * scale_factor, tier_1_y * scale_factor),
                ]
            )
        return cls._template

    def __init__(self, center_x="0", center_y="0", angle="0"):
        """Place the tree at (center_x, center_y) rotated by ``angle`` degrees.

        Each argument is passed to ``Decimal``. Strings and Decimals are kept
        as given; floats are converted with their exact binary expansion.
        """
        self.center_x = Decimal(center_x)
        self.center_y = Decimal(center_y)
        self.angle = Decimal(angle)

        # Rotate about the local origin first, then move into position.
        rotated = affinity.rotate(self._outline(), float(self.angle), origin=(0, 0))
        self.polygon = affinity.translate(
            rotated, xoff=float(self.center_x * scale_factor), yoff=float(self.center_y * scale_factor)
        )

    def get_params(self):
        """Return ``(center_x, center_y, angle)`` as Decimals."""
        return self.center_x, self.center_y, self.angle

    def set_params(self, center_x, center_y, angle):
        """Move/rotate this tree in place and rebuild its polygon.

        The values go through ``str`` first, so a float argument is stored
        as its short repr rather than its exact binary expansion.
        """
        self.__init__(str(center_x), str(center_y), str(angle))

    def clone(self):
        """Create an independent copy with the same position and angle."""
        return ChristmasTree(str(self.center_x), str(self.center_y), str(self.angle))

In [None]:
def format_time(elapsed):
    """Format a duration in seconds as an ``h:mm:ss`` string.

    The duration is first rounded to a whole number of seconds with the
    built-in ``round`` and then rendered through ``datetime.timedelta``.
    """
    delta = datetime.timedelta(seconds=int(round(elapsed)))
    return f"{delta}"

In [None]:
def plot_trees(placed_trees):
    """Draw the trees and the smallest enclosing axis-aligned square.

    Polygon coordinates are divided by ``scale_factor`` back to puzzle units.
    The square starts at the lower-left corner of the union's bounding box
    along its longer side and is centered along the shorter side. The title
    reports N, the score ``side**2 / N`` and the bounding-box width/height.
    """
    count = len(placed_trees)
    _, ax = plt.subplots(figsize=(6, 6))
    palette = plt.cm.viridis([k / count for k in range(count)])

    union_bounds = unary_union([t.polygon for t in placed_trees]).bounds

    for k, tree in enumerate(placed_trees):
        raw_xs, raw_ys = tree.polygon.exterior.xy
        xs = [Decimal(v) / scale_factor for v in raw_xs]
        ys = [Decimal(v) / scale_factor for v in raw_ys]
        ax.plot(xs, ys, color=palette[k])
        ax.fill(xs, ys, alpha=0.5, color=palette[k])
        ax.text(float(tree.center_x), float(tree.center_y), str(k))

    minx, miny, maxx, maxy = (Decimal(v) / scale_factor for v in union_bounds)

    width = maxx - minx
    height = maxy - miny
    side_length = max(width, height)

    # Pad the shorter dimension symmetrically so the square is centered on it.
    if width >= height:
        square_x = minx
    else:
        square_x = minx - (side_length - width) / 2
    if height >= width:
        square_y = miny
    else:
        square_y = miny - (side_length - height) / 2

    ax.add_patch(
        Rectangle(
            (float(square_x), float(square_y)),
            float(side_length),
            float(side_length),
            fill=False,
            edgecolor="red",
            linewidth=2,
            linestyle="--",
        )
    )

    margin = Decimal("0.1")
    ax.set_xlim(float(square_x - margin), float(square_x + side_length + margin))
    ax.set_ylim(float(square_y - margin), float(square_y + side_length + margin))
    ax.set_aspect("equal", adjustable="box")
    ax.axis("off")

    score = side_length**2 / count
    plt.title(f"N = {len(placed_trees)} Score = {score:.6f}\n W = {width:.6f}  H = {height:.6f}")
    plt.show()

In [None]:
def load_configuration_from_df(n, existing_df):
    """Build the list of trees for puzzle ``n`` from a submission DataFrame.

    Rows whose ``id`` starts with ``"{n:03d}_"`` are selected. The first
    character (the ``s`` prefix) of ``x``, ``y`` and ``deg`` is dropped
    before the values are handed to ``ChristmasTree``.

    Raises:
        RuntimeError: if the number of matching rows is not ``n``.
    """
    prefix = f"{n:03d}_"
    selected = existing_df[existing_df["id"].str.startswith(prefix)]
    trees = [
        ChristmasTree(row["x"][1:], row["y"][1:], row["deg"][1:])
        for _, row in selected.iterrows()
    ]
    if len(trees) == n:
        return trees
    raise RuntimeError("Number of trees is inconsistent")

In [None]:
def to_str(x: Decimal):
    """Format a value for the submission file: ``"s"`` + its float repr.

    Note the value goes through ``float``, so it is rounded to double
    precision (and may come out in exponent form, e.g. ``s1e-05``).
    """
    return "s" + repr(float(x))

In [None]:
class SimulatedAnnealing:
    """Simulated annealing for periodic (lattice) packings of Christmas trees.

    The optimisation state is a small list of *primitive* trees plus two
    lattice translations: ``a`` along x and ``b`` along y. ``translate``
    expands this state into the full arrangement of
    ``len(trees) * nt[0] * nt[1]`` trees, plus one extra column
    (``append_x``) and/or row (``append_y``) of copies of
    ``primitive_trees[1]``. The objective minimised is
    ``max(width, height) ** 2 / N`` of the full arrangement's bounding box.

    Every step draws one move uniformly from ``len(trees) + 2`` options:
    perturb one primitive tree, rescale ``a``/``b``, or rotate all primitive
    trees by a common angle. A move that creates an overlap inside a 2x2
    lattice patch is undone immediately; any other move is accepted or undone
    by the Metropolis criterion at the current temperature ``T``.
    """

    _COOLING_SCHEDULES = ("linear", "exponential", "polynomial")

    def __init__(
        self,
        trees,
        a,
        b,
        nt,
        append_x,
        append_y,
        Tmax,
        Tmin,
        nsteps,
        nsteps_per_T,
        cooling,
        alpha,
        n,
        position_delta,
        angle_delta,
        angle_delta2,
        delta_t,
        random_state,
        log_freq,
    ):
        """Store the run parameters and seed the global ``random`` module.

        Args:
            trees: primitive trees (cloned in ``solve``; not modified).
            a, b: initial lattice translations along x / y; if ``a`` is None
                both are taken from the primitive trees' bounding box.
            nt: ``[nx, ny]`` lattice repetitions.
            append_x, append_y: add one extra column / row built from
                ``trees[1]`` (requires at least two primitive trees).
            Tmax, Tmin: start / end temperature.
            nsteps: number of temperature levels.
            nsteps_per_T: moves attempted at each temperature.
            cooling: ``"linear"``, ``"exponential"`` or ``"polynomial"``.
            alpha: stored only; not used by this class.
            n: exponent of the polynomial cooling schedule.
            position_delta: max |dx|, |dy| of a single-tree move.
            angle_delta: max rotation (degrees) of a single-tree move.
            angle_delta2: max common rotation (degrees) of all trees.
            delta_t: max relative change of ``a`` and ``b`` per move.
            random_state: seed passed to ``random.seed`` (global generator).
            log_freq: print progress every ``log_freq`` moves.

        Raises:
            ValueError: if ``cooling`` is not a supported schedule (it would
                otherwise silently keep the temperature at ``Tmax``).
        """
        if cooling not in self._COOLING_SCHEDULES:
            raise ValueError(
                f"Unknown cooling schedule {cooling!r}; expected one of {self._COOLING_SCHEDULES}"
            )
        self.trees = trees
        self.a = a
        self.b = b
        self.nt = nt
        self.append_x = append_x
        self.append_y = append_y
        self.Tmax = Tmax
        self.Tmin = Tmin
        self.nsteps = nsteps
        self.nsteps_per_T = nsteps_per_T
        self.cooling = cooling
        self.alpha = alpha
        self.n = n
        self.position_delta = position_delta
        self.angle_delta = angle_delta
        self.angle_delta2 = angle_delta2
        self.delta_t = delta_t
        self.log_freq = log_freq
        random.seed(random_state)

    def perturb_tree(self, tree):
        """Randomly shift and rotate ``tree`` in place.

        Returns:
            The previous ``(x, y, angle)`` so the move can be undone.
        """
        old_x, old_y, old_angle = tree.get_params()
        dx = Decimal(str(random.uniform(-self.position_delta, self.position_delta)))
        dy = Decimal(str(random.uniform(-self.position_delta, self.position_delta)))
        dangle = Decimal(str(random.uniform(-self.angle_delta, self.angle_delta)))
        tree.set_params(old_x + dx, old_y + dy, (old_angle + dangle) % 360)
        return old_x, old_y, old_angle

    def get_lengths(self, current_trees):
        """Return ``(width, height)`` of the trees' joint bounding box in puzzle units."""
        scale = float(scale_factor)
        xys = np.concatenate([np.asarray(t.polygon.exterior.xy).T / scale for t in current_trees])
        min_x, min_y = xys.min(axis=0)
        max_x, max_y = xys.max(axis=0)
        return max_x - min_x, max_y - min_y

    def has_overlap(self, trees, n=None):
        """Return True if two tree polygons overlap.

        Two polygons overlap when they intersect but do not merely touch, so
        shared boundaries are allowed. With ``n`` given, only pairs that
        involve ``trees[n]`` are checked.
        """
        if len(trees) <= 1:
            return False

        def overlaps(first, second):
            return first.polygon.intersects(second.polygon) and not first.polygon.touches(second.polygon)

        if n is None:
            for i, tree1 in enumerate(trees):
                for tree2 in trees[i + 1:]:
                    if overlaps(tree1, tree2):
                        return True
            return False

        target = trees[n]
        return any(overlaps(tree, target) for i, tree in enumerate(trees) if i != n)

    def _acceptance_probability(self, current_energy, new_energy, temperature):
        """Metropolis probability: 1 for an improvement, else exp(-delta / T)."""
        if new_energy < current_energy:
            return 1.0
        return math.exp((current_energy - new_energy) / temperature)

    def perturb_translations(self, a, b):
        """Rescale the lattice translations by independent random relative factors.

        Returns:
            ``(new_a, new_b, old_a, old_b)``; each new value is the old one
            times ``1 + u`` with ``u`` uniform in ``[-delta_t, delta_t]``.
        """
        old_a, old_b = a, b
        da = random.uniform(-self.delta_t, self.delta_t)
        db = random.uniform(-self.delta_t, self.delta_t)
        return old_a + old_a * da, old_b + old_b * db, old_a, old_b

    def rotate_all(self, trees):
        """Rotate every tree in place by one common random angle.

        Returns:
            ``(trees, old_angles)`` -- the same list object, and the previous
            angles in the same order, so the move can be undone.
        """
        old_angles = []
        dangle = Decimal(str(random.uniform(-self.angle_delta2, self.angle_delta2)))
        for tree in trees:
            x, y, old_angle = tree.get_params()
            old_angles.append(old_angle)
            tree.set_params(x, y, (old_angle + dangle) % 360)
        return trees, old_angles

    def translate(self, primitive_trees, a, b, nt, append_x=False, append_y=False):
        """Expand the primitive trees into a full ``nt[0] x nt[1]`` lattice.

        Each primitive tree is copied to every offset ``(ix * a, iy * b)``.
        ``append_x`` adds copies of ``primitive_trees[1]`` at ``ix = nt[0]``
        for every ``iy``; ``append_y`` adds copies at ``iy = nt[1]`` for every
        ``ix``. Either option requires at least two primitive trees.
        """

        def shifted(tree, ix, iy):
            return ChristmasTree(
                center_x=tree.center_x + Decimal(ix * a),
                center_y=tree.center_y + Decimal(iy * b),
                angle=tree.angle,
            )

        lattice_trees = [
            shifted(tree, ix, iy)
            for tree in primitive_trees
            for ix in range(nt[0])
            for iy in range(nt[1])
        ]
        if append_x:
            lattice_trees.extend(shifted(primitive_trees[1], nt[0], iy) for iy in range(nt[1]))
        if append_y:
            lattice_trees.extend(shifted(primitive_trees[1], ix, nt[1]) for ix in range(nt[0]))
        return lattice_trees

    def save_structure(self, trees, a, b):
        """Write the primitive cell and lattice settings to ``best_structure.json``."""
        structure = {
            "primitive_trees": [list(map(float, tree.get_params())) for tree in trees],
            "a": a,
            "b": b,
            "nt": self.nt,
            "append_x": self.append_x,
            "append_y": self.append_y,
        }
        with open("best_structure.json", "w", encoding="utf-8") as outfile:
            json.dump(structure, outfile, indent=4)

    def _score(self, lattice_trees):
        """Objective: squared side of the bounding square divided by N."""
        return max(self.get_lengths(lattice_trees)) ** 2 / len(lattice_trees)

    def _apply_move(self, move, primitive_trees, a, b):
        """Apply move ``move``; return ``(a, b, undo)`` with the data to revert it."""
        n_primitive = len(primitive_trees)
        if move < n_primitive:
            return a, b, self.perturb_tree(primitive_trees[move])
        if move == n_primitive:
            new_a, new_b, old_a, old_b = self.perturb_translations(a, b)
            return new_a, new_b, (old_a, old_b)
        _, old_angles = self.rotate_all(primitive_trees)
        return a, b, old_angles

    def _undo_move(self, move, primitive_trees, a, b, undo):
        """Revert a move made by ``_apply_move``; return the restored ``(a, b)``."""
        n_primitive = len(primitive_trees)
        if move < n_primitive:
            primitive_trees[move].set_params(*undo)
            return a, b
        if move == n_primitive:
            return undo
        for tree, old_angle in zip(primitive_trees, undo):
            x, y, _ = tree.get_params()
            tree.set_params(x, y, old_angle)
        return a, b

    def _next_temperature(self, T, step):
        """Temperature for the level after ``step`` under the configured schedule."""
        if self.cooling == "linear":
            return T - (self.Tmax - self.Tmin) / self.nsteps
        if self.cooling == "exponential":
            Tfactor = -math.log(self.Tmax / self.Tmin)
            return self.Tmax * math.exp(Tfactor * (step + 1) / self.nsteps)
        if self.cooling == "polynomial":
            return self.Tmin + (self.Tmax - self.Tmin) * ((self.nsteps - step - 1) / self.nsteps) ** self.n
        return T

    def _log_progress(self, t0, T, step1, current_score, best_score):
        """Print progress every ``log_freq`` moves and on the last move of a level."""
        if step1 % self.log_freq == 0 or step1 == (self.nsteps_per_T - 1):
            t1 = format_time(time.time() - t0)
            print(
                f"T: {T:.3e}  Step: {step1:6}  Score: {current_score:8.5f}  Best score: {best_score:8.5f}  Elapsed Time: {t1}",
                flush=True,
            )

    def solve(self):
        """Run the annealing schedule.

        Writes ``best_structure.json`` every time a new best is found.

        Returns:
            ``(best_score, best_trees)``. ``best_trees`` are clones of the full
            lattice arrangement with the lowest score seen, which is the
            initial arrangement if nothing better was ever accepted.

        Raises:
            RuntimeError: if the initial full arrangement already overlaps.
        """
        t0 = time.time()  # Measure starting time
        T = self.Tmax

        primitive_trees = [tree.clone() for tree in self.trees]
        if self.a is None:
            a, b = self.get_lengths(primitive_trees)
        else:
            a, b = copy.copy(self.a), copy.copy(self.b)

        lattice_trees = self.translate(primitive_trees, a, b, self.nt, self.append_x, self.append_y)
        if self.has_overlap(lattice_trees):
            raise RuntimeError("Initial tree configuration has overlap!")

        current_score = self._score(lattice_trees)
        best_score = current_score
        # Start from the full lattice, not the primitive cell: if no move ever
        # improves the score, callers must still get an N-tree arrangement.
        best_trees = [tree.clone() for tree in lattice_trees]

        for step in range(self.nsteps):
            for step1 in range(self.nsteps_per_T):
                # Moves: 0..n-1 perturb one tree, n rescales a/b, n+1 rotates all.
                move = random.randint(0, len(primitive_trees) + 1)
                a, b, undo = self._apply_move(move, primitive_trees, a, b)

                # A 2x2 patch contains every nearest-neighbour offset
                # (+-a, +-b and both diagonals) of the periodic lattice.
                if self.has_overlap(self.translate(primitive_trees, a, b, [2, 2])):
                    a, b = self._undo_move(move, primitive_trees, a, b, undo)
                else:
                    lattice_trees = self.translate(primitive_trees, a, b, self.nt, self.append_x, self.append_y)
                    new_score = self._score(lattice_trees)
                    acceptance = self._acceptance_probability(current_score, new_score, T)
                    if acceptance > random.random():
                        current_score = new_score
                        if new_score < best_score:
                            best_score = new_score
                            best_trees = [tree.clone() for tree in lattice_trees]
                            self.save_structure(primitive_trees, a, b)
                    else:
                        a, b = self._undo_move(move, primitive_trees, a, b, undo)

                self._log_progress(t0, T, step1, current_score, best_score)

            # lower the temperature
            T = self._next_temperature(T, step)

        return best_score, best_trees

In [None]:
# Primitive cell of two trees: (x, y, angle in degrees) for each tree.
seed_params = [
    [-4.191683864412409, -4.498489528496051, 74.54421568660419],
    [-4.92202045352307, -4.727639556649786, 254.5401905706735],
]
initial_trees = []
for x, y, deg in seed_params:
    initial_trees.append(ChristmasTree(x, y, deg))

plot_trees(initial_trees)

In [None]:
%%writefile config.yaml

params:
    nt: [3, 5]                 # lattice repetitions along x and y
    a: 0.8744896974945239      # initial lattice translation along x
    b: 0.7499641699190263      # initial lattice translation along y
    append_x: False            # add one extra column of copies of tree 1
    append_y: False            # add one extra row of copies of tree 1
    Tmax: 0.001                # starting temperature
    Tmin: 0.000001             # final temperature
    alpha: 0.99                # stored by SimulatedAnnealing but not used
    nsteps: 10                 # number of temperature levels
    nsteps_per_T: 10000        # moves attempted per temperature level
    cooling: 'exponential'     # 'linear', 'exponential' or 'polynomial'
    n: 4                       # exponent of the polynomial schedule
    position_delta: 0.002      # max |dx|, |dy| of a single-tree move
    angle_delta: 1.            # max rotation (degrees) of a single-tree move
    angle_delta2: 1.           # max common rotation (degrees) of all trees
    delta_t: 0.002             # max relative change of a and b per move
    random_state: 42           # seed for Python's random module
    log_freq: 5000             # print progress every log_freq moves

In [None]:
with open("config.yaml", "r") as file_obj:
    config = yaml.safe_load(file_obj)

# 3 x 5 lattice of the primitive cell (N = 2 * 3 * 5 = 30).
sa = SimulatedAnnealing(initial_trees, **config["params"])
score, trees = sa.solve()

# Optimised arrangements, keyed by their number of trees N.
new_trees = {len(trees): trees}
plot_trees(trees)

In [None]:
# 4 x 5 lattice of the primitive cell (N = 2 * 4 * 5 = 40).
config["params"].update(nt=[4, 5])
sa = SimulatedAnnealing(initial_trees, **config["params"])
score, trees = sa.solve()
new_trees[len(trees)] = trees
plot_trees(trees)

In [None]:
# 4 x 6 lattice of the primitive cell (N = 2 * 4 * 6 = 48).
config["params"].update(nt=[4, 6])
sa = SimulatedAnnealing(initial_trees, **config["params"])
score, trees = sa.solve()
new_trees[len(trees)] = trees
plot_trees(trees)

In [None]:
# 4 x 7 lattice of the primitive cell (N = 2 * 4 * 7 = 56).
config["params"].update(nt=[4, 7])
sa = SimulatedAnnealing(initial_trees, **config["params"])
score, trees = sa.solve()
new_trees[len(trees)] = trees
plot_trees(trees)

In [None]:
# 5 x 7 lattice plus one extra row of copies of tree 1 (N = 70 + 5 = 75).
config["params"].update(nt=[5, 7], append_y=True)
sa = SimulatedAnnealing(initial_trees, **config["params"])
score, trees = sa.solve()
new_trees[len(trees)] = trees
plot_trees(trees)

In [None]:
# 5 x 8 lattice, no extra row (N = 2 * 5 * 8 = 80).
config["params"].update(nt=[5, 8], append_y=False)
sa = SimulatedAnnealing(initial_trees, **config["params"])
score, trees = sa.solve()
new_trees[len(trees)] = trees
plot_trees(trees)

In [None]:
# 6 x 7 lattice, no extra row (N = 2 * 6 * 7 = 84).
config["params"].update(nt=[6, 7], append_y=False)
sa = SimulatedAnnealing(initial_trees, **config["params"])
score, trees = sa.solve()
new_trees[len(trees)] = trees
plot_trees(trees)

In [None]:
# 7 x 11 lattice plus one extra row of copies of tree 1 (N = 154 + 7 = 161).
config["params"].update(nt=[7, 11], append_y=True)
sa = SimulatedAnnealing(initial_trees, **config["params"])
score, trees = sa.solve()
new_trees[len(trees)] = trees
plot_trees(trees)

In [None]:
# 8 x 12 lattice plus one extra row of copies of tree 1 (N = 192 + 8 = 200).
config["params"].update(nt=[8, 12], append_y=True)
sa = SimulatedAnnealing(initial_trees, **config["params"])
score, trees = sa.solve()
new_trees[len(trees)] = trees
plot_trees(trees)

In [None]:
df = pd.read_csv("/kaggle/input/why-not/submission.csv")

rows = []
for n in range(1, 201):
    # Always load (and thereby validate) the baseline puzzle, then prefer the
    # annealed arrangement for this N when one exists.
    trees = load_configuration_from_df(n, df)
    chosen = new_trees[n] if n in new_trees else trees
    rows.extend(
        {
            "id": f"{n:03d}_{i_t}",
            "x": to_str(tree.center_x),
            "y": to_str(tree.center_y),
            "deg": to_str(tree.angle),
        }
        for i_t, tree in enumerate(chosen)
    )

df = pd.DataFrame(rows)
df.to_csv("submission.csv", index=False)

In [None]:
def get_tree_list_side_lenght(tree_list: list[ChristmasTree]) -> Decimal:
    """Return the side of the smallest axis-aligned square holding all trees.

    That is the larger of the union's bounding-box width and height,
    converted back to puzzle units by dividing by ``scale_factor``.
    """
    min_x, min_y, max_x, max_y = unary_union([tree.polygon for tree in tree_list]).bounds
    longest = max(max_x - min_x, max_y - min_y)
    return Decimal(longest) / scale_factor

def get_total_score(dict_of_side_length: dict[str, Decimal]):
    """Sum ``side ** 2 / N`` over all groups, parsing each key as N.

    Returns the integer 0 for an empty mapping.
    """
    return sum(
        (side ** 2 / Decimal(group) for group, side in dict_of_side_length.items()),
        0,
    )

def parse_csv(csv_path) -> tuple[dict[str, list[ChristmasTree]], dict[str, Decimal]]:
    """Read a submission CSV into per-puzzle tree lists and side lengths.

    ``id`` values have the form ``"<group>_<item>"``. The ``s`` prefix is
    stripped from ``x``, ``y`` and ``deg`` before the strings are handed to
    ``ChristmasTree``.

    Returns:
        ``(dict_of_tree_list, dict_of_side_length)``, both keyed by the group
        string taken from ``id`` (e.g. ``"007"``); the side length is
        computed with ``get_tree_list_side_lenght``.
    """
    result = pd.read_csv(csv_path)
    for column in ('x', 'y', 'deg'):
        result[column] = result[column].str.strip('s')
    # Split on the first underscore only: exactly two target columns exist.
    result[['group_id', 'item_id']] = result['id'].str.split('_', n=1, expand=True)

    dict_of_tree_list = {}
    dict_of_side_length = {}
    for group_id, group_data in result.groupby('group_id'):
        tree_list = [
            ChristmasTree(center_x=row['x'], center_y=row['y'], angle=row['deg'])
            for _, row in group_data.iterrows()
        ]
        dict_of_tree_list[group_id] = tree_list
        dict_of_side_length[group_id] = get_tree_list_side_lenght(tree_list)

    return dict_of_tree_list, dict_of_side_length


# Load current best solution
current_solution_path = 'submission.csv'
dict_of_tree_list, dict_of_side_length = parse_csv(current_solution_path)

# Calculate current total score
current_score = get_total_score(dict_of_side_length)

# Backward pass: for N = 200 down to 2, try to get a better (N-1)-tree
# packing by deleting one tree from the current N-tree packing. Improvements
# cascade, because puzzle N-1 is processed next with its (possibly updated)
# configuration.
for n_main in range(200, 1, -1):
    group_id_main = f'{n_main:03d}'
    print(f'Current box: {group_id_main}')

    group_id_prev = f'{n_main - 1:03d}'
    best_side_length = dict_of_side_length[group_id_prev]
    best_tree_to_delete = None
    main_trees = dict_of_tree_list[group_id_main]

    # Try to delete each tree one by one and select the best option.
    # Measuring the side length only reads the polygons, so a list slice
    # (no cloning) is enough for each candidate.
    for tree_to_delete in range(len(main_trees)):
        candidate_tree_list = main_trees[:tree_to_delete] + main_trees[tree_to_delete + 1:]
        candidate_side_length = get_tree_list_side_lenght(candidate_tree_list)

        if candidate_side_length < best_side_length:
            print(f' improvement {best_side_length:0.8f} -> {candidate_side_length:0.8f}')
            best_side_length = candidate_side_length
            best_tree_to_delete = tree_to_delete

    # Save the best (as clones, so the two puzzles never share tree objects)
    if best_tree_to_delete is not None:
        candidate_tree_list = [
            tree.clone() for k, tree in enumerate(main_trees) if k != best_tree_to_delete
        ]
        dict_of_tree_list[group_id_prev] = candidate_tree_list
        dict_of_side_length[group_id_prev] = get_tree_list_side_lenght(candidate_tree_list)

# Recalculate current total score
new_score = get_total_score(dict_of_side_length)
print(f'{current_score=:0.8f} {new_score=:0.8f} ({current_score - new_score:0.8f})')

# Save results: one row per tree with id "<group>_<index>"; values keep the
# full Decimal text (no float round-trip) behind the "s" prefix.
tree_data = pd.DataFrame(
    [
        {
            'id': f'{group_name}_{item_id}',
            'x': f's{tree.center_x}',
            'y': f's{tree.center_y}',
            'deg': f's{tree.angle}',
        }
        for group_name, tree_list in dict_of_tree_list.items()
        for item_id, tree in enumerate(tree_list)
    ]
)
tree_data.to_csv('submission.csv', index=False)