## 🏆 Santa 2024 - The Perplexity Permutation Puzzle

* Author: **Roy Ma** (*creative-ataraxia*)
* Date: Feb 1st, 2025
* Objective: Optimize 6 string samples to achieve the lowest LM perplexity rating
* Results: placed in the top **7.27%** of the leaderboard; `sample_0, 1, 2 and 5` achieved scores on par with the top-3 level.

## Imports

In [2]:
# Standard Libs
import argparse
import copy
import gc
import hashlib
import itertools
import math
import os
import pickle
import random
import subprocess
from time import time
import warnings
from collections import Counter
from pathlib import Path
from typing import Generator, List, Optional, Union

# 3rd parties libs
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import transformers
from scipy.stats import spearmanr
from torch.utils.data import DataLoader, Dataset, random_split
from tqdm.notebook import tqdm
from IPython.display import clear_output

os.environ["OMP_NUM_THREADS"] = "1"                                      # cap OpenMP at a single thread
os.environ["TOKENIZERS_PARALLELISM"] = "false"                           # disable tokenizer parallelism (author's note: prevents race conditions)
PAD_TOKEN_LABEL_ID = torch.nn.CrossEntropyLoss().ignore_index            # label value that CrossEntropyLoss skips; used to mask padding positions when scoring
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")    # use the GPU when CUDA is available, otherwise fall back to the CPU
warnings.simplefilter("ignore")                                          # silence all warnings from here on

## Constants

Defines the various paths, and asserts input data's consistency with competition formats

In [26]:
PATH_MODEL = Path("I:/Models/LLMs/gemma_2_9b")             # path to the local LLM used for scoring
PATH_SAMPLE_INPUT = Path("../input/sample_submission.csv") # path to Kaggle's default sample input
DF_SAMPLE = pd.read_csv(PATH_SAMPLE_INPUT)                 # sample submission; its "text" column holds the word lists to permute

PATH_SAVE = Path("./1st_save")                             # save location for model checkpoints, score memos and search results
PATH_SAVE.mkdir(parents=True, exist_ok=True)               # created eagerly so later writes can assume it exists

NUM_SAMPLES = len(DF_SAMPLE)
assert NUM_SAMPLES == 6                                    # total number of samples should be 6

# Word count of every sample, in row order.
LIST_NUM_WORDS = [len(DF_SAMPLE.loc[sample_id, "text"].split()) for sample_id in range(NUM_SAMPLES)]
assert LIST_NUM_WORDS == [10, 20, 20, 30, 50, 100]         # number of words in each sample needs to be consistent with competition requirements

# One vocabulary per sample: each distinct word maps to its rank in alphabetical order.
# Duplicated words share a single id. Used for pretraining the pruning NN later.
LIST_WORD_TO_ID: list[dict[str, int]] = [
    {word: i for i, word in enumerate(sorted(set(DF_SAMPLE.loc[sample_id, "text"].split())))}
    for sample_id in range(NUM_SAMPLES)
]

# Sanity check: the vocabulary of the 100-word sample must match the expected 89 distinct words exactly.
assert LIST_WORD_TO_ID[5] == {'advent': 0, 'and': 1, 'angel': 2, 'as': 3, 'bake': 4, 'beard': 5, 'believe': 6, 'bow': 7, 'candle': 8, 'candy': 9, 'card': 10, 'carol': 11, 'cheer': 12, 'chimney': 13, 'chocolate': 14, 'cookie': 15, 'decorations': 16, 'doll': 17, 'dream': 18, 'drive': 19, 'eat': 20, 'eggnog': 21, 'elf': 22, 'family': 23, 'fireplace': 24, 'from': 25, 'fruitcake': 26, 'game': 27, 'gifts': 28, 'gingerbread': 29, 'give': 30, 'greeting': 31, 'grinch': 32, 'have': 33, 'hohoho': 34, 'holiday': 35, 'holly': 36, 'hope': 37, 'in': 38, 'is': 39, 'it': 40, 'jingle': 41, 'joy': 42, 'jump': 43, 'kaggle': 44, 'laugh': 45, 'magi': 46, 'merry': 47, 'milk': 48, 'mistletoe': 49, 'naughty': 50, 'nice': 51, 'night': 52, 'not': 53, 'nutcracker': 54, 'of': 55, 'ornament': 56, 'paper': 57, 'peace': 58, 'peppermint': 59, 'poinsettia': 60, 'polar': 61, 'puzzle': 62, 'reindeer': 63, 'relax': 64, 'scrooge': 65, 'season': 66, 'sing': 67, 'sleep': 68, 'sleigh': 69, 'snowglobe': 70, 'star': 71, 'stocking': 72, 'that': 73, 'the': 74, 'to': 75, 'toy': 76, 'unwrap': 77, 'visit': 78, 'walk': 79, 'we': 80, 'wish': 81, 'with': 82, 'wonder': 83, 'workshop': 84, 'wrapping': 85, 'wreath': 86, 'you': 87, 'yuletide': 88}

## Scorer

> Main scoring module that computes perplexity scores from input texts. Used for all loss and optimization metrics.

In [27]:
class ParticipantVisibleError(Exception):
    """Custom exception type; adds no behaviour on top of ``Exception``."""

class PerplexityCalculator:
    """
    Perplexity scorer backed by a causal language model.

    Loads a tokenizer and a causal LM from ``model_path`` and exposes
    ``get_score`` to turn one text (or a list of texts) into perplexity values.

    Parameters
    ----------
    model_path : str
        Path to a pre-trained language model
    load_in_8bit : bool, default=False
        Load the model with 8-bit quantization. Requires CUDA.
    device_map : str, default="auto"
        Device mapping forwarded to the model loader.
    """

    def __init__(
        self,
        model_path: str,
        load_in_8bit: bool = False,
        device_map: str = "auto",
    ):
        # Right-side padding keeps the real tokens at the start of every row.
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            model_path, padding_side="right"
        )

        # Pick the loader arguments that differ between the quantized and the plain path.
        if load_in_8bit:
            if DEVICE.type != "cuda":
                raise ValueError("8-bit quantization requires CUDA device")
            load_kwargs = {
                "quantization_config": transformers.BitsAndBytesConfig(load_in_8bit=True)
            }
        else:
            load_kwargs = {
                "torch_dtype": torch.float16 if DEVICE.type == "cuda" else torch.float32
            }

        self.model = transformers.AutoModelForCausalLM.from_pretrained(
            model_path,
            device_map=device_map,
            **load_kwargs,
        )

        # Per-token losses are needed so that each sequence can be averaged on its own.
        self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")

        self.model.eval()

    # Scoring is much faster when every batch fits on the GPU than when parts are
    # offloaded to the CPU, so pick batch_size according to the VRAM you have.
    def get_score(self, input_texts: Union[str, List[str]], batch_size=10) -> Union[float, List[float]]:
        """
        Compute the perplexity of one text or of each text in a list.

        Parameters
        ----------
        input_texts : str or list[str]
            Text(s) to score.
        batch_size : int, default=10
            Number of texts sent through the model at once.

        Returns
        -------
        float or list[float]
            A single perplexity for a ``str`` input, otherwise one per text.
        """
        is_single = isinstance(input_texts, str)
        texts = [input_texts] if is_single else input_texts

        # Ceil-divide the number of texts by the batch size.
        num_batches, leftover = divmod(len(texts), batch_size)
        if leftover != 0:
            num_batches += 1

        mean_losses = []
        for batch_idx in range(num_batches):
            chunk = texts[batch_idx * batch_size:(batch_idx + 1) * batch_size]

            with torch.no_grad():
                # Sequence boundary tokens are added by hand, so the tokenizer
                # is told below not to add its own.
                wrapped = [
                    f"{self.tokenizer.bos_token}{text}{self.tokenizer.eos_token}"
                    for text in chunk
                ]

                encoded = self.tokenizer(
                    wrapped,
                    return_tensors="pt",
                    add_special_tokens=False,
                    padding=True,
                )
                encoded.pop("token_type_ids", None)

                model_inputs = {key: tensor.to(DEVICE) for key, tensor in encoded.items()}

                logits = self.model(**model_inputs, use_cache=False)["logits"]

                # Padding positions get the ignore label so they add nothing to the loss.
                # This overwrites input_ids in place, which is safe after the forward pass.
                labels = model_inputs["input_ids"]
                labels[labels == self.tokenizer.pad_token_id] = PAD_TOKEN_LABEL_ID

                # Position t predicts token t + 1: drop the last prediction and the first label.
                pred_logits = logits[..., :-1, :].contiguous()
                target_ids = labels[..., 1:].contiguous()

                token_loss = self.loss_fct(
                    pred_logits.view(-1, pred_logits.size(-1)), target_ids.view(-1)
                )
                token_loss = token_loss.view(len(logits), -1)

                # Average over the non-padding targets of each sequence.
                num_valid = (target_ids != PAD_TOKEN_LABEL_ID).sum(dim=-1)
                sequence_loss = torch.sum(token_loss, -1) / num_valid

                mean_losses += sequence_loss.cpu().tolist()

        perplexities = list(map(math.exp, mean_losses))

        if is_single:
            return perplexities[0]
        return perplexities

    def clear_gpu_memory(self) -> None:
        """Drop the model and tokenizer references and empty the CUDA caches."""
        if not torch.cuda.is_available():
            return

        # Remove the heavy attributes if they are still attached to this instance.
        for attr_name in ("model", "tokenizer"):
            if hasattr(self, attr_name):
                delattr(self, attr_name)

        gc.collect()

        # Clear the CUDA cache and reset memory stats
        with DEVICE:
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
            torch.cuda.reset_peak_memory_stats()

## Exploratory Data Analysis

### Initial EDA

First, let's look at what the data provided by the competition looks like

In [23]:
# Load the competition's sample submission into its own dataframe for exploration
sample_df = pd.read_csv('../input/sample_submission.csv')
# Last expression of the cell: the notebook renders the first five rows
sample_df.head()

Unnamed: 0,id,text
0,0,advent chimney elf family fireplace gingerbrea...
1,1,advent chimney elf family fireplace gingerbrea...
2,2,yuletide decorations gifts cheer holiday carol...
3,3,yuletide decorations gifts cheer holiday carol...
4,4,hohoho candle poinsettia snowglobe peppermint ...


Looks like there are 2 columns, one is 'id', the other is 'text'; let's take a look at all the texts

In [12]:
# Print every row as a plain dict so the full, untruncated text is visible
for _, row in sample_df.iterrows():
    print(row.to_dict())

{'id': 0, 'text': 'advent chimney elf family fireplace gingerbread mistletoe ornament reindeer scrooge'}
{'id': 1, 'text': 'advent chimney elf family fireplace gingerbread mistletoe ornament reindeer scrooge walk give jump drive bake the sleep night laugh and'}
{'id': 2, 'text': 'yuletide decorations gifts cheer holiday carol magi nutcracker polar grinch sleigh chimney workshop stocking ornament holly jingle beard naughty nice'}
{'id': 3, 'text': 'yuletide decorations gifts cheer holiday carol magi nutcracker polar grinch sleigh chimney workshop stocking ornament holly jingle beard naughty nice sing cheer and of the is eat visit relax unwrap'}
{'id': 4, 'text': 'hohoho candle poinsettia snowglobe peppermint eggnog fruitcake chocolate candy puzzle game doll toy workshop wonder believe dream hope peace joy merry season greeting card wrapping paper bow fireplace night cookie milk star wish wreath angel the to of and in that have it not with as you from we kaggle'}
{'id': 5, 'text': 'adven

- Looks like we have 6 strings of text, made up of holiday-themed vocabulary separated by whitespace.

- According to the rules of the competition, our goal is to rearrange the words in a way that results in:
    - the **lowest perplexity score** as evaluated by an LLM.
    - we'll use `Gemma-2 9b` as recommended by the competition.
- Perplexity basically measures how "surprised" the LLM is by each next token.
    - if a token has low perplexity, it means the LLM is likely to sample it as the next token in sequence.
    - else if a token has high perplexity, it means the LLM is unlikely to sample it as the next token in sequence.
- The final leaderboard score is the **average perplexity score** across all 6 sample strings.

### Brainstorm Strategies

> Since the goal is to find a specific permutation for each string, that results in the lowest perplexity score, can we just brute force all permutations to find the best score?

In [25]:
# Number of whitespace-separated words in each sample
sample_df["word_count"] = sample_df["text"].str.split().str.len()
# Size of the brute-force search space: word_count!, formatted in scientific notation
sample_df["permutations"] = sample_df["word_count"].apply(math.factorial).apply(lambda x: f"{x:.2e}")
# Last expression of the cell: show the three columns without the index
sample_df[["id", "word_count", "permutations"]].style.hide(axis="index")

id,word_count,permutations
0,10,3.63e+06
1,20,2.43e+18
2,20,2.43e+18
3,30,2.65e+32
4,50,3.04e+64
5,100,9.33e+157


- As we can see, only the 1st sample `sample_0`, with a word count of 10, with a total number of about 3.6 million permutations, is feasible for brute-force.
- To brute-force the other strings would take until the end of the universe.

> Then, since the competition's evaluation metric is perplexity, and perplexity for `language models` should mean that the better the text flows according to natural language, the better the perplexity, right? Let's do some experiments.

In [28]:
# First, instantiate the scorer: this loads the tokenizer and the LLM found at PATH_MODEL
scorer = PerplexityCalculator(model_path=str(PATH_MODEL))

Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

In [31]:
# Smoke-test the scorer on the first sample in its default word order
sample_0 = sample_df.text[0]
print(f"{sample_0=}")
score = scorer.get_score(sample_0)  # a single string in, a single float out
print()
print(f"sample 0 perplexity score: {score}")

sample_0='advent chimney elf family fireplace gingerbread mistletoe ornament reindeer scrooge'

sample 0 perplexity score: 3857.64626282737


- The scorer works correctly, currently, the default arrangement of sample 0 scored about `3857.64`
- Next, let's see what the longest sample, `sample 5` scores:

In [32]:
# Score the longest sample (row 5) in its default word order
sample_5 = sample_df.text[5]
print(f"{sample_5=}")
score = scorer.get_score(sample_5)  # a single string in, a single float out
print()
print(f"sample 5 perplexity score: {score}")

sample_5='advent chimney elf family fireplace gingerbread mistletoe ornament reindeer scrooge walk give jump drive bake the sleep night laugh and yuletide decorations gifts cheer holiday carol magi nutcracker polar grinch sleigh chimney workshop stocking ornament holly jingle beard naughty nice sing cheer and of the is eat visit relax unwrap hohoho candle poinsettia snowglobe peppermint eggnog fruitcake chocolate candy puzzle game doll toy workshop wonder believe dream hope peace joy merry season greeting card wrapping paper bow fireplace night cookie milk star wish wreath angel the to of and in that have it not with as you from we kaggle'

sample 5 perplexity score: 354.636652059297


- Default sample 5 scored much better than sample 0 due to it having many more "stopwords" available.
- Such as *"the, of, and, that" etc.*
- Now, what if we put all these stopwords at the start of the sentence? you would think that it would score worse, since the sentence would not make sense anymore, right?

In [38]:
def rearrange_words(sentence, stopwords):
    """
    Move every stopword of a sentence to the front, keeping relative order.

    Parameters
    ----------
    sentence : str
        Whitespace-separated words.
    stopwords : str or iterable of str
        The stopwords, either as one whitespace-separated string or as a
        collection of words.

    Returns
    -------
    str
        The stopwords (in their original order, duplicates kept) followed by
        the remaining words (also in original order), joined by single spaces.
    """
    # A plain string must be split into words first. Testing `word in "that the"`
    # would be a substring check and would wrongly treat e.g. "hat" as a stopword.
    if isinstance(stopwords, str):
        stopword_set = set(stopwords.split())
    else:
        stopword_set = set(stopwords)

    words = sentence.split()
    leading = [word for word in words if word in stopword_set]
    trailing = [word for word in words if word not in stopword_set]
    return " ".join(leading + trailing)

# Space-separated stopwords to move to the front of sample 5
stopwords = "and as from have in is it not of that the to we with you"
sorted_sample_5 = rearrange_words(sample_5, stopwords)
print(sorted_sample_5)

# Score the rearranged sentence for comparison with the default order
score = scorer.get_score(sorted_sample_5)
print()
print(f"new sample 5 perplexity score: {score}")

the and and of the is the to of and in that have it not with as you from we advent chimney elf family fireplace gingerbread mistletoe ornament reindeer scrooge walk give jump drive bake sleep night laugh yuletide decorations gifts cheer holiday carol magi nutcracker polar grinch sleigh chimney workshop stocking ornament holly jingle beard naughty nice sing cheer eat visit relax unwrap hohoho candle poinsettia snowglobe peppermint eggnog fruitcake chocolate candy puzzle game doll toy workshop wonder believe dream hope peace joy merry season greeting card wrapping paper bow fireplace night cookie milk star wish wreath angel kaggle

new sample 5 perplexity score: 233.4866728754002


- Surprise: sorted sample 5 actually scores much better, even with a non-grammatical arrangement like this.
- This is likely due to the 'quirks' of how LLMs learn their embedding layers.
- So, for the purpose of getting a better perplexity score, focusing solely on grammar and semantics will not do.

> So, it looks like we'll need to employ optimization algorithms to find the permutations we need; there are many optimization algorithms, such as simulated annealing, genetic algorithm etc. As usual, the devil is in the (implementation) details, below is an example algorithm called `Iterated Local Search`.

# Iterated Local Search

- The core idea of the **iterated local search** is that
    - we perform **local search** and greedily accept better solutions:
```pseudocode
FUNCTION solve(initial_perm)
    best_perm ← initial_perm
    LOOP FOREVER:
        perms ← best_perm
        REPEAT n TIMES:
            perms ← kick(perms)
            perms ← local_search(perms)
            IF score(perms) < score(best_perm) THEN
                best_perm ← perms

```
- In **local search**, we:
    - using depth-first-search, we continue to search as long as a better solution can be found from this `initial_perm`
```pseudocode
FUNCTION local_search(initial_perm)
    perms ← initial_perm
    LOOP FOREVER:
        result ← DFS(perms)
        IF result IS NULL THEN
            RETURN perms
        perms ← result 
```

- In **DFS**, we:
    - immediately return if a better solution is found
    - else, accept solutions with a worse score, as allowed by a `threshold` parameter
    - and recursively search until a valid result or no result can be found

```pseudocode
FUNCTION DFS(best_score, perms, depth)
    FOR EACH neighbor IN make_neighbors(perms) DO:
        IF score(neighbor) < score(perms) THEN
            RETURN neighbor
        ELSE IF score(neighbor) < get_threshold(best_score, depth) THEN
            result ← DFS(best_score, neighbor, depth + 1)
            IF result IS NOT NULL THEN
                RETURN result
```

> The following code implements the pseudocode above

## Utils

Various utility functions to aid the run. Refer to the docstrings for the purpose of each util function

In [39]:
def get_path_words_best(n_idx):
    """
    Load the best-scoring saved permutation of a sample.

    Looks at every ``*.txt`` file in the sample's save folder, takes the score
    encoded at the start of each file name, and reads back the file with the
    lowest one. The loaded words are asserted to be a permutation of the
    sample's original words.

    Parameters
    ----------
    n_idx : int
        The sample index to retrieve.

    Returns
    -------
    tuple[float, list[str]] or (None, None)
        The best score and permutation if available; otherwise, (None, None).
    """
    sample_dir = PATH_SAVE / f"{n_idx:04d}"
    reference_words = DF_SAMPLE.loc[n_idx, "text"].split(" ")
    candidate_files = list(sample_dir.glob("*.txt"))

    if len(candidate_files) == 0:
        return None, None

    # File names look like "<score>_<md5>.txt"; the score is the part before the first "_".
    scores = [float(candidate.stem.partition("_")[0]) for candidate in candidate_files]
    best = np.argmin(scores)
    best_score = scores[best]
    best_words = candidate_files[best].read_text().split(" ")

    assert sorted(best_words) == sorted(reference_words)
    return best_score, best_words


def save_text(get_score, n_idx, text, verbose=0):
    """
    Score a permutation and write it to the sample's save folder.

    The text is first checked to be a permutation of the sample's original
    words; if it is not, a warning is printed and nothing is saved. Otherwise
    the text is scored and written to ``<score>_<md5 of text>.txt``.

    Parameters
    ----------
    get_score : function
        The scoring function, called as ``get_score(n_idx, text)``.
    n_idx : int
        The sample id
    text : str
        The permutation to be saved.
    verbose : int, optional
        Verbosity: prints score if >=1 and text if >=2, default 0.

    Returns
    -------
    float or None
        The calculated score if saving is successful; otherwise, None.
    """
    sample_dir = PATH_SAVE / f"{n_idx:04d}"
    sample_dir.mkdir(exist_ok=True)

    words_original = DF_SAMPLE.loc[n_idx, "text"].split(" ")
    words = text.split(" ")

    # Refuse to save anything that is not a permutation of the original words.
    if sorted(words) != sorted(words_original):
        print(f"[Warning] words are not the same with original: {words} != {words_original}")
        return None

    text = " ".join(words)
    score = get_score(n_idx, text)

    if verbose >= 1:
        print(f"score:{score:.4f}")
    if verbose >= 2:
        print(text)

    # The hash keeps file names unique for different texts with the same rounded score.
    digest = hashlib.md5(text.encode()).hexdigest()
    target = sample_dir / f"{score:.4f}_{digest}.txt"
    target.write_text(text)

    return score


def load_score_memo() -> tuple[dict[str, float], dict[str, float]]:
    """
    Read both score memo pickles from the save folder.

    ``score_memo.pkl`` and ``score_memo_with_error.pkl`` are unpickled from
    ``PATH_SAVE``; a memo whose file does not exist is returned as an empty
    dict. The memos let an optimization run skip texts it has already scored.

    Returns
    -------
    tuple[dict[str, float], dict[str, float]]
        A tuple containing the successful score memo and errored score memo.
    """
    memos = []
    for filename in ("score_memo.pkl", "score_memo_with_error.pkl"):
        memo_path = PATH_SAVE / filename
        if not memo_path.exists():
            memos.append({})
            continue
        with memo_path.open("rb") as handle:
            memos.append(pickle.load(handle))

    return memos[0], memos[1]


def save_score_memo(
    score_memo: dict[str, float],
    score_memo_with_error: dict[str, float],
):
    """
    Merge the given memos into the ones on disk and write them back.

    The memos currently on disk are loaded, updated with the entries passed in
    (new entries win on key clashes), and pickled again to the same two files.

    Parameters
    ----------
    score_memo : dict[str, float]
        Dictionary mapping text strings to their scores.
    score_memo_with_error : dict[str, float]
        Dictionary mapping text strings to their scores, including error cases.
    """
    merged_memo, merged_memo_with_error = load_score_memo()
    merged_memo.update(score_memo)
    merged_memo_with_error.update(score_memo_with_error)

    outputs = (
        ("score_memo.pkl", merged_memo),
        ("score_memo_with_error.pkl", merged_memo_with_error),
    )
    for filename, payload in outputs:
        with (PATH_SAVE / filename).open("wb") as handle:
            pickle.dump(payload, handle)


def _get_score(
    scorer,
    score_memo: dict[str, float],
    score_memo_with_error: dict[str, float],
    text: Union[str, List[str]],
) -> Union[float, List[float]]:
    """
    Get the perplexity score(s) for input text(s), use memoization to avoid recomputation.

    Checks if the score for the given text or texts is already cached in the
    score memos. If not, computes the score using the provided scorer, updates
    the memo, and returns the score(s).

    Parameters
    ----------
    scorer : PerplexityCalculator
        An instance that can compute the score for text.
    score_memo : dict[str, float]
        Cache of previously computed scores.
    score_memo_with_error : dict[str, float]
        Cache for scores computed when errors occurred.
    text : str or list[str]
        A single text string or a list of text strings for which to compute perplexity.

    Returns
    -------
    float or list[float]
        The perplexity score if a single string is provided; otherwise, return a list of scores.

    Raises
    ------
    ValueError
        If `text` is neither a string nor a list of strings.
    """
    if isinstance(text, str):
        if text in score_memo:
            return score_memo[text]
        
        score = scorer.get_score(text)
        score_memo[text] = score
        return score
    
    elif isinstance(text, list):
        list_text_new = [t for t in text if t not in score_memo and t not in score_memo_with_error]
        
        if list_text_new:
            list_score_new = scorer.get_score(list_text_new)
            for t, s in zip(list_text_new, list_score_new):
                score_memo_with_error[t] = s
        
        return [score_memo.get(t, score_memo_with_error.get(t, None)) for t in text]
    
    else:
        raise ValueError("text is not str nor list[str]")


## Pre-train Pruning NN

- In this module we train a lightweight CNN, to be used for pruning permutation candidates generated.
- So we save on scoring compute cost when doing optimization runs.
- we will need a score memo saved first to use as training data.
- I will not run the training here in this notebook, but I've included the training results in the main run below.

In [45]:
class SEBlock(nn.Module):
    """
    Squeeze-and-Excitation block. *arXiv:1709.01507*

    This block adaptively recalibrates channel-wise feature responses by explicitly modeling
    interdependencies between channels. It performs global average pooling followed by two
    fully connected layers with a ReLU and sigmoid activation to compute per-channel weights.

    tldr: emphasize important feature channels
    """
    def __init__(self, channels: int, reduction: int = 16) -> None:
        """
        Parameters
        ----------
        channels : int
            Number of input channels.
        reduction : int, optional
            Reduction factor for the hidden layer in the SE block, default 16.
            The hidden layer always keeps at least one unit.
        """
        super().__init__()
        # Without the clamp, channels < reduction would give a zero-width
        # bottleneck and the block would ignore its input entirely.
        hidden = max(1, channels // reduction)
        self.avg_pool = nn.AdaptiveAvgPool1d(1)
        self.fc1 = nn.Linear(channels, hidden)
        self.fc2 = nn.Linear(hidden, channels)
        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (B, C, L).

        Returns
        -------
        torch.Tensor
            Output tensor after channel-wise recalibration, same shape as `x`.
        """
        b, c, _ = x.size()
        y = self.avg_pool(x).view(b, c)     # squeeze: one average per channel
        y = self.fc1(y)
        y = self.relu(y)
        y = self.fc2(y)
        y = self.sigmoid(y).view(b, c, 1)   # excitation: per-channel gate in (0, 1)
        return x * y                        # gate is broadcast along L


class ResidualBlock(nn.Module):
    """
    Two 1-D convolutions followed by an SEBlock, wrapped in a skip connection.

    The block's input is added back onto the processed features before the
    final activation.

    tldr: helps in training deeper networks.
    """
    def __init__(self, channels: int, kernel_size: int = 3) -> None:
        """
        Parameters
        ----------
        channels : int
            Number of input and output channels.
        kernel_size : int, optional
            Convolutional kernel size, default 3.
        """
        super().__init__()
        half_kernel = kernel_size // 2
        self.conv1 = nn.Conv1d(channels, channels, kernel_size, padding=half_kernel)
        self.conv2 = nn.Conv1d(channels, channels, kernel_size, padding=half_kernel)
        self.se = SEBlock(channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (B, channels, L).

        Returns
        -------
        torch.Tensor
            Output tensor after applying convolutional layers, SE block, and residual addition.
        """
        features = self.relu(self.conv1(x))
        features = self.se(self.conv2(features))
        features += x                     # skip connection
        return self.relu(features)


class PruneNet(nn.Module):
    """
    Score-regression network used for pretraining: an embedding layer, a
    convolutional stem, a stack of ResidualBlocks, and a final fully connected
    layer that outputs one scalar per input sequence.
    """
    def __init__(self, vocab_size: int, channels: int, num_blocks: int) -> None:
        """
        Parameters
        ----------
        vocab_size : int
            Size of the vocabulary.
        channels : int
            Number of channels for the embeddings and convolutional layers.
        num_blocks : int
            Number of ResidualBlocks to stack.
        """
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, channels)
        self.conv_stem = nn.Conv1d(
            channels,
            channels,
            kernel_size=7,
            stride=1,
            padding=3,
            bias=False,
        )
        self.relu_stem = nn.ReLU(inplace=True)
        self.blocks = nn.ModuleList(ResidualBlock(channels) for _ in range(num_blocks))
        self.fc = nn.Linear(channels, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (B, L) containing token indices.

        Returns
        -------
        torch.Tensor
            Output tensor of shape (B, 1) representing a scalar score for each sequence.
        """
        hidden = self.embedding(x).transpose(1, 2)             # (B, L, channels) -> (B, channels, L)
        hidden = self.relu_stem(self.conv_stem(hidden))
        for residual_block in self.blocks:
            hidden = residual_block(hidden)
        pooled = F.adaptive_avg_pool1d(hidden, 1).squeeze(-1)  # (B, channels)
        return self.fc(pooled)                                 # (B, 1)


class ScoreDataset(Dataset):
    """
    PyTorch Dataset pairing token-id sequences with their target scores.

    Holds the two tensors as given and casts each sequence to ``long`` only
    when an item is fetched.
    """
    def __init__(self, X: torch.Tensor, y: torch.Tensor) -> None:
        """
        Parameters
        ----------
        X : torch.Tensor
            Tensor of input sequences (token IDs) of shape (N, L).
        y : torch.Tensor
            Tensor of target scores (log values) of shape (N,).
        """
        self.X, self.y = X, y

    def __len__(self) -> int:
        """
        Number of samples, taken from the first dimension of ``X``.
        """
        return len(self.X)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Fetch one (sequence, score) pair.

        Parameters
        ----------
        idx : int
            Index of the sample.

        Returns
        -------
        tuple[torch.Tensor, torch.Tensor]
            A tuple (X, y) where X is the token tensor cast to ``long`` and y is the target score.
        """
        return self.X[idx].long(), self.y[idx]


def prepare_dataset(training_data: list[dict[bytes, float]], sample_id: int) -> tuple[dict[str, int], Dataset]:
    """
    Build the training dataset for one sample.

    Every (compressed text, score) entry of the sample becomes one row: the
    compressed text is expanded into a row of int8 values (presumably word ids
    from the sample's vocabulary; verify against the code that writes the memo)
    and the score is stored as its natural logarithm.

    Parameters
    ----------
    training_data : list[dict[bytes, float]]
        A list where each element corresponds to a sample and maps compressed text (bytes)
        to its associated score.
    sample_id : int
        The index of the sample to prepare the dataset for.

    Returns
    -------
    tuple[dict[str, int], Dataset]
        A tuple containing the word-to-ID mapping and the constructed ScoreDataset.

    Raises
    ------
    ValueError
        When input is sample_id 1 or 2.
    """
    vocab = LIST_WORD_TO_ID[sample_id]
    seq_len = LIST_NUM_WORDS[sample_id]

    if sample_id in (1, 2):
        raise ValueError("sample_id 1 and 2 are not supported")

    sample_memo = training_data[sample_id]
    n_rows = len(sample_memo)
    tokens = torch.empty((n_rows, seq_len), dtype=torch.int8)
    log_scores = torch.empty(n_rows, dtype=torch.float)

    progress = tqdm(sample_memo.items(), mininterval=30)
    for row, (encoded, ppl) in enumerate(progress):
        # Every entry must cover exactly the sample's word count.
        assert len(encoded) == seq_len
        tokens[row] = torch.tensor(list(encoded), dtype=torch.int8)
        log_scores[row] = math.log(ppl)

    return vocab, ScoreDataset(tokens, log_scores)


def train_model(sample_id: int = 5) -> None:
    """
    Train the pruning network on previously recorded scores for one sample.

    Loads the saved score memos, builds a dataset for ``sample_id``, holds out
    5% of it for validation, and trains with L1 loss on log-transformed scores.
    After every epoch the validation Spearman correlation is printed and a
    checkpoint is written to ``PATH_SAVE / "pretrain"``.

    Parameters
    ----------
    sample_id : int, optional
        The 0-based sample index to train on (default 5).

    Raises
    ------
    ValueError
        If the training split holds fewer rows than one batch, so that no
        batch would ever be produced (``drop_last=True``).
    """
    # Only the second value returned by load_score_memo() is used as training data.
    _, training_data = load_score_memo()
    path_pretrain = PATH_SAVE / Path("pretrain")
    path_pretrain.mkdir(parents=True, exist_ok=True)
    num_epochs = 20
    batch_size = 4096
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    word_to_id, dataset = prepare_dataset(training_data, sample_id)
    del training_data  # the tensors inside `dataset` are all that is needed from here on

    total_size = len(dataset)
    val_size = int(total_size * 0.05)        # split 5% data as validation
    train_size = total_size - val_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        num_workers=4,
        shuffle=True,
        drop_last=True,
    )
    # With drop_last=True a split smaller than one batch yields no batches at all;
    # fail early instead of dividing by zero when averaging the epoch loss.
    if len(train_loader) == 0:
        raise ValueError(
            f"Not enough training data for sample {sample_id}: "
            f"{train_size} rows < batch_size {batch_size}"
        )
    val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=4, shuffle=False)
    model = PruneNet(vocab_size=len(word_to_id), channels=128, num_blocks=12)
    model = model.to(device)
    print(f"Number of parameters: {sum(p.numel() for p in model.parameters())}")

    optimizer = optim.AdamW(model.parameters(), lr=0.005)

    for epoch in range(num_epochs):
        # ---- training pass ----
        model.train()
        running_loss = 0.0
        count = 0
        pbar = tqdm(train_loader, desc=f"Training Epoch {epoch+1}/{num_epochs}", mininterval=30)
        for X, y in pbar:
            X, y = X.to(device), y.to(device)
            optimizer.zero_grad()
            outputs = model(X)  # (B, 1)
            loss = F.l1_loss(outputs.squeeze(-1), y)
            loss.backward()
            optimizer.step()
            # Weight the batch-mean loss by batch size to get a per-sample average below.
            running_loss += loss.item() * X.size(0)
            count += X.size(0)
            pbar.set_postfix(refresh=False, loss=loss.item())

        avg_loss = running_loss / count

        # ---- validation pass: rank correlation between targets and predictions ----
        model.eval()
        all_preds = []
        all_targets = []
        with torch.no_grad():
            for X, y in val_loader:
                X, y = X.to(device), y.to(device)
                preds: torch.Tensor = model(X).squeeze(-1)
                all_preds.extend(preds.detach().cpu().tolist())
                all_targets.extend(y.detach().cpu().tolist())
        corr, _ = spearmanr(all_targets, all_preds)
        print(
            f"Epoch {epoch+1}/{num_epochs}, Train Loss: {avg_loss:.4f}, "
            f"Val Spearman: {corr:.4f}"
        )

        # One checkpoint per epoch; the file name carries the 1-based epoch number.
        torch.save(
            {"model": model.state_dict(), "word_to_id": word_to_id},
            path_pretrain / f"model_{sample_id}_epoch_{epoch+1}.pt",
        )

### Run Pretraining

In [6]:
# only run if previous score memos are saved, to be used as training data for the nn
# train_model()

## Main

> This is the main search loop module

### Score Estimator

- Use the pretrained pruning neural network to estimate scores for candidate permutations.
    - This lets us prune potential candidates so we don't have to fully score everything through the LLM, which is expensive.
- The estimator first tries to load a model checkpoint generated during an optimization run; otherwise it falls back to the pretrained checkpoint.
- During optimization run, the nn's weights will be updated on-the-fly.

In [40]:
class ScoreEstimator:
    """
    Use the pretrained pruning neural network to estimate scores for candidate permutations,
    so we don't have to pass every candidate into the full LM to score, which is expensive.

    The estimator can also be updated online from freshly scored candidates
    (see ``update_parameters``) and persisted with ``save_model``.
    """
    def __init__(
        self,
        sample_id: int,
        epoch: int,
        device: torch.device,
    ):
        """
        Parameters
        ----------
        sample_id : int
            Index of the sample for which the estimator is used on.
        epoch : int
            Epoch number of the pretrained model (used when no online model exists).
        device : torch.device
            The device (CPU or CUDA) on which the model will run.
        """
        self.sample_id = sample_id
        self.length = LIST_NUM_WORDS[sample_id]

        # Prepare the data for the online model.
        online_model_dir = PATH_SAVE / "online"
        online_model_dir.mkdir(parents=True, exist_ok=True)
        self.online_model_path = online_model_dir / f"model_{sample_id}.pt"
        self.pretrained_model_path = (PATH_SAVE / f"pretrain/model_{sample_id}_epoch_{epoch}.pt")

        # Load an existing online model if available, otherwise fall back to the pretrained model.
        # If neither file exists the network keeps its freshly initialised weights.
        if self.online_model_path.exists():
            print(f"[ScoreEstimator] Load online model: {self.online_model_path}")
            checkpoint = torch.load(self.online_model_path, map_location=device, weights_only=True)
        elif self.pretrained_model_path.exists():
            print(f"[ScoreEstimator] Load pretrained model: {self.pretrained_model_path}")
            checkpoint = torch.load(self.pretrained_model_path, map_location=device, weights_only=True)
        else:
            checkpoint = None

        # Retrieve word-to-ID mapping for the sample.
        self.word_to_id = LIST_WORD_TO_ID[sample_id]

        self.device = device
        self.model = PruneNet(vocab_size=len(self.word_to_id), channels=128, num_blocks=12).to(device)
        if checkpoint is not None:
            # Checkpoints store the mapping they were trained with; a different mapping
            # would silently feed the network ids it was not trained on, so flag it.
            saved_mapping = checkpoint.get("word_to_id")
            if saved_mapping is not None and saved_mapping != self.word_to_id:
                print(
                    f"[ScoreEstimator] Warning: word_to_id stored in the checkpoint for sample "
                    f"{sample_id} differs from the mapping currently in use"
                )
            self.model.load_state_dict(checkpoint["model"])
        # eval() switches the module to inference behaviour; it does not freeze the weights
        # (update_parameters() still trains them).
        self.model.eval()
        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=0.00005) # tune eta here

        # For debugging, buffers to collect texts, scores, and predictions.
        self.buffer_texts = []
        self.buffer_scores = []
        self.buffer_predictions = []
        self.update_count = 0


    def _encode(self, texts: list[str]) -> torch.Tensor:
        """Turn whitespace-separated texts into a long tensor of word ids on ``self.device``."""
        rows = []
        for text in texts:
            words = text.split()
            # Every candidate must be a permutation of the sample's words.
            assert len(words) == self.length
            rows.append([self.word_to_id[w] for w in words])
        return torch.tensor(rows, dtype=torch.long, device=self.device)


    def estimate_scores(self, texts: list[str]) -> np.ndarray:
        """
        Predict a score for each candidate text without tracking gradients.

        Parameters
        ----------
        texts : list[str]
            A list of text strings representing candidate solutions.

        Returns
        -------
        np.ndarray
            An array of predicted log perplexity scores (one per text).
        """
        X = self._encode(texts)
        with torch.no_grad():
            preds: torch.Tensor = self.model(X).squeeze(-1)  # (B,)

        return preds.detach().cpu().numpy()                  # log perplexity


    def update_parameters(self, texts: list[str], scores: list[float]):
        """
        Update the model weights using the data and scores obtained during the optimization run.
        This is the "online training" aspect.

        One optimizer step is taken per call, using L1 loss between the
        predictions and the natural log of ``scores``. Every 256 calls the
        Spearman correlation between the buffered scores and predictions is
        printed and the buffers are cleared.

        Parameters
        ----------
        texts : list[str]
            A list of text strings representing candidate solutions.
        scores : list[float]
            A list of target scores corresponding to the texts.
        """
        X = self._encode(texts)
        self.model.train()
        pred: torch.Tensor = self.model(X).squeeze(-1)  # (B,)
        target = torch.tensor(scores, dtype=torch.float, device=self.device).log()
        loss = F.l1_loss(pred, target)
        loss.backward()

        self.optimizer.step()
        self.optimizer.zero_grad(set_to_none=True)
        # Back to inference behaviour for the next estimate_scores() call.
        self.model.eval()

        self.buffer_texts.extend(texts)
        self.buffer_scores.extend(scores)
        self.buffer_predictions.extend(pred.detach().cpu().tolist())
        self.update_count += 1

        if self.update_count % 256 == 0:
            # Spearman is rank-based, so comparing raw scores with log predictions is fine.
            corr, _ = spearmanr(self.buffer_scores, self.buffer_predictions)
            print(f"[ScoreEstimator] Spearman: {corr:.4f}")
            self.buffer_texts.clear()
            self.buffer_scores.clear()
            self.buffer_predictions.clear()


    def save_model(self):
        """Write the current weights and word-to-ID mapping to ``self.online_model_path``."""
        torch.save(
            {"word_to_id": self.word_to_id,
             "model": self.model.state_dict()},
             self.online_model_path,
        )

### Make Neighbors

- Using 2 types of moves to create new perms to continue the optimization run.
    - Move A: Inserting into Sorted Segments.
    - Move B: Random Center-Based Shifts.

In [41]:
def make_neighbors(words: list[str]) -> Generator[tuple[list[str], tuple], None, None]:
    """
    Lazily enumerate neighbouring permutations of ``words``.

    Two kinds of moves are produced for each block length (1..3, or 1..2 when
    there are 50 or more words):

    * Move A (block length >= 2 only): lift a contiguous block out of the
      sequence and merge it, in sorted order, into a non-overlapping
      already-sorted run of at least 4 words.
    * Move B: for every shuffled centre position, swap two adjacent blocks,
      i.e. move a block of the current length left or right across a
      neighbouring span of any admissible size.

    Every permutation is yielded at most once and the input order itself is
    never yielded. The input list is not modified.

    Parameters
    ----------
    words : list[str]
        The original sequence of words to perturb.

    Yields
    ------
    tuple[list[str], tuple]
        ``(permutation, meta)``. ``meta`` is
        ``(source_l, source_r, target_l, target_r, 3)`` for Move A and
        ``(left, center, right, 0)`` for Move B.
    """
    base = words.copy()
    n_words = len(base)
    seen = {tuple(base)}

    def _collect(candidate: list[str], meta: tuple, bucket: list) -> None:
        # Record the candidate unless this exact ordering was produced before.
        key = tuple(candidate)
        if key in seen:
            return
        seen.add(key)
        bucket.append((candidate, meta))

    # Maximal non-decreasing runs, stored as half-open [start, stop) index pairs.
    runs: list[list[int]] = []
    for pos in range(n_words - 1):
        if base[pos] > base[pos + 1]:
            continue
        if runs and runs[-1][1] == pos + 1:
            runs[-1][1] = pos + 2          # extend the run that ends at this pair
        else:
            runs.append([pos, pos + 2])    # start a new two-word run

    # Only runs of at least 4 words are used as merge targets.
    long_runs = [(start, stop) for start, stop in runs if stop - start >= 4]

    # Longer inputs get a smaller maximum block length to keep the neighbourhood affordable.
    longest_block = 3 if n_words < 50 else 2

    for block_len in range(1, longest_block + 1):
        # ---- Move A: merge a block into a sorted run (skipped for single words) ----
        if block_len >= 2:
            merged: list = []
            for src_lo in range(n_words - block_len + 1):
                src_hi = src_lo + block_len
                block = base[src_lo:src_hi]
                for run_lo, run_hi in long_runs:
                    run = base[run_lo:run_hi]
                    if src_hi <= run_lo:
                        # Block sits entirely before the run.
                        candidate = (
                            base[:src_lo]
                            + base[src_hi:run_lo]
                            + sorted(block + run)
                            + base[run_hi:]
                        )
                    elif run_hi <= src_lo:
                        # Block sits entirely after the run.
                        candidate = (
                            base[:run_lo]
                            + sorted(run + block)
                            + base[run_hi:src_lo]
                            + base[src_hi:]
                        )
                    else:
                        # Block and run overlap: nothing to merge.
                        continue
                    _collect(candidate, (src_lo, src_hi, run_lo, run_hi, 3), merged)

            random.shuffle(merged)
            yield from merged

        # ---- Move B: swap adjacent blocks around a randomly ordered centre ----
        centers = range(block_len, n_words - block_len + 1)
        for pivot in random.sample(centers, len(centers)):
            shifted: list = []

            # The block [pivot, hi) jumps left over [lo, pivot), with lo moving outwards.
            hi = pivot + block_len
            lo = pivot - block_len
            while lo >= 0:
                candidate = base[:lo] + base[pivot:hi] + base[lo:pivot] + base[hi:]
                _collect(candidate, (lo, pivot, hi, 0), shifted)
                lo -= 1

            # The block [lo, pivot) jumps right over [pivot, hi); the equal-length swap
            # was already produced above, so hi starts one step further out.
            lo = pivot - block_len
            hi = pivot + block_len + 1
            while hi <= n_words:
                candidate = base[:lo] + base[pivot:hi] + base[lo:pivot] + base[hi:]
                _collect(candidate, (lo, pivot, hi, 0), shifted)
                hi += 1

            random.shuffle(shifted)
            yield from shifted

In [42]:
def free_memory():
    """
    Release memory that is no longer referenced.

    Runs Python's garbage collector first, then asks PyTorch to release its
    cached CUDA memory. The order matters: objects must be collected before
    the blocks they held can be handed back.
    """
    gc.collect()
    torch.cuda.empty_cache()

### Optimizer

> This is the main optimization implementation that controls all the search mechanisms.

- the entry point is the `run()` method.
- during each iteration of `run()`, the `_hillclimbing()` method is called, which implements the local search logic.
- if the search is stuck, a random perturbation (a "kick") is applied to try to jostle the run out of a local minimum.

In [43]:
class Optimization:
    """
    Orchestrates the entire optimization process for all samples.
        - local search / hill climbing
        - perturbations / kicks
        - neural network online update
    """
    def __init__(
        self,
        flag_use_best=True,
        flag_shuffle=True,
    ):
        """
        Build the scorer, load the score memos, create one estimator per sample
        and score the starting solution of every sample.

        Parameters
        ----------
        flag_use_best : bool, optional
            If True, start from the best saved solution; otherwise, use the input text.
        flag_shuffle : bool, optional
            Only relevant when ``flag_use_best`` is False: shuffle the input
            words to create the initial candidate.
        """
        self.flag_use_best = flag_use_best
        self.flag_shuffle = flag_shuffle

        # Scorer, memoised scores, and the timestamp of the last memo save.
        self.calculator = PerplexityCalculator(model_path=str(PATH_MODEL))
        memo, memo_with_error = load_score_memo()
        self.score_memo = memo
        self.score_memo_with_error = memo_with_error
        self.last_time_score_memo_saved = time()

        # One estimator per sample. Only samples 4 and 5 (the longest ones) name a
        # pretrained epoch; -1 is a placeholder for the other samples.
        pretrained_epochs = [-1, -1, -1, -1, 20, 11]
        estimators = []
        for sample_id, epoch in enumerate(pretrained_epochs):
            estimators.append(
                ScoreEstimator(
                    sample_id=sample_id,
                    epoch=epoch,
                    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
                )
            )
        self.score_estimators = estimators

        # Current best solution and score for each sample.
        self.list_words_best: list[list[str]] = []
        self.list_perplexity_best: list[float] = []

        for idx in range(NUM_SAMPLES):
            if not self.flag_use_best:
                # Start from the competition input, optionally shuffled.
                raw_text: str = DF_SAMPLE.iloc[idx, 1]
                list_words = raw_text.split()
                if self.flag_shuffle:
                    random.shuffle(list_words)
            else:
                # Start from the best solution saved so far.
                _path, list_words = get_path_words_best(idx)
                assert list_words is not None

            joined = " ".join(list_words)
            self.list_words_best.append(list_words.copy())
            score_new = self._calc_perplexity(idx, joined)
            self.list_perplexity_best.append(score_new)

            print(f"idx:{idx} score:{score_new:.4f}")

        # Independent copies that track the all-time best (for recovery when stuck).
        self.list_words_best_all = copy.deepcopy(self.list_words_best)
        self.list_perplexity_best_all = copy.deepcopy(self.list_perplexity_best)

        # Per-sample kick counter.
        self.list_num_kick = [1] * NUM_SAMPLES


    def _calc_perplexity(self, n_idx: int, text: Union[str, list[str]]) -> Union[float, list[float]]:
        """
        Score one text or a batch of texts through the shared scorer and memos.

        Parameters
        ----------
        n_idx : int
            The sample index. Accepted for symmetry with the other helpers;
            it is not used by this method.
        text : str or list[str]
            The candidate text(s) to score.

        Returns
        -------
        float or list[float]
            Whatever ``_get_score`` returns for ``text`` (a single score for a
            string, a list of scores for a list).
        """
        memo = self.score_memo
        memo_with_error = self.score_memo_with_error
        return _get_score(self.calculator, memo, memo_with_error, text)


    def _get_best(self, n_idx: int) -> tuple[list[str], float]:
        """
        Retrieve the current best solution and its score for a given sample from the class attributes.

        Parameters
        ----------
        n_idx : int
            The problem index.

        Returns
        -------
        tuple[list[str], float]
            The best word list and its perplexity score.
        """
        return self.list_words_best[n_idx], self.list_perplexity_best[n_idx]


    def _update_best_all(self, n_idx: int, words: list[str], perplexity: float):
        """
        Update class attribute of the all-time best solution for a sample if the new candidate is better.

        Parameters
        ----------
        n_idx : int
            The problem index.
        words : list[str]
            The candidate word list.
        perplexity : float
            The candidate perplexity score.
        """
        if perplexity < self.list_perplexity_best_all[n_idx]:
            self.list_words_best_all[n_idx] = words.copy()
            self.list_perplexity_best_all[n_idx] = perplexity


    def _get_best_all(self, n_idx: int) -> tuple[list[str], float]:
        """
        Retrieve the all-time best solution for a given sample.

        Parameters
        ----------
        n_idx : int
            The problem index.

        Returns
        -------
        tuple[list[str], float]
            The all-time best word list and its perplexity score.
        """
        return self.list_words_best_all[n_idx], self.list_perplexity_best_all[n_idx]


    def _hillclimbing(
        self,
        n_idx: int,
        words_best: list[str],
        perplexity_best: float,
        score_estimator: ScoreEstimator,
        iter_total: int = 500, # set total iterations to run here
        pbar=None,
        print_every: int = 100,
    ) -> tuple[list[str], float]:
        """
        Run one hill-climbing search starting from the current best solution.

        The inner ``search`` function works depth-first. At each node it:

        1. collects a batch of unvisited neighbours from ``make_neighbors``,
        2. keeps the 112 with the lowest estimated score plus 16 random others,
        3. scores those 128 with ``self._calc_perplexity`` and feeds them to
           ``score_estimator.update_parameters`` (online training),
        4. returns immediately if the best of the batch beats ``perplexity_best``,
        5. otherwise recurses, in random order, into every batch member whose
           score is below ``perplexity_best * threshold[depth]``.

        The search stops at the first strict improvement, when a node runs out
        of neighbours, or once ``iter_total`` batches have been scored.

        Parameters
        ----------
        n_idx : int
            The problem index (0..5); selects the depth-to-threshold table.
        words_best : list[str]
            The current best candidate solution.
        perplexity_best : float
            The perplexity score of the current best solution.
        score_estimator : ScoreEstimator
            An instance used to estimate scores and update model parameters.
        iter_total : int, optional
            Budget of scored batches, shared across all recursion levels, by default 500.
        pbar : tqdm instance, optional
            Accepted for compatibility; the only use of it is commented out below.
        print_every : int, optional
            Frequency (in iterations) to print status messages, by default 100.

        Returns
        -------
        tuple[list[str], float]
            The improved word list and its perplexity score, or the inputs
            unchanged when no improvement was found.

        Raises
        ------
        ValueError
            If ``n_idx`` is not one of 0..5.
        """
        iter_count = 0
        # Local class to record accepted and rejected candidates.
        class Stats:
            """Histogram of accepted/rejected candidates keyed by their position in the batch."""
            def __init__(self, max_value: int):
                # max_value is the exclusive upper bound of the recorded positions.
                self.max_value = max_value
                self.accepted = Counter()
                self.rejected = Counter()

            def summary(self) -> str:
                """Return both counters collapsed into 8 equal-width bins, plus the total."""
                n_bins = 8
                accepted = [0] * n_bins
                for value, count in self.accepted.items():
                    assert 0 <= value < self.max_value
                    accepted[value * n_bins // self.max_value] += count
                rejected = [0] * n_bins
                for value, count in self.rejected.items():
                    assert 0 <= value < self.max_value
                    rejected[value * n_bins // self.max_value] += count
                return (
                    f" accepted:{accepted}, rejected:{rejected}"
                    f"     - total:{sum(accepted) + sum(rejected)}"
                )

        # Number of candidates fully scored per batch (112 top-estimated + 16 random).
        batch_size = 128
        stats = Stats(max_value=batch_size)

        # Orderings that have already been expanded by search(); shared by all recursion levels.
        visited = set()

        def search(
            words: list[str], depth: int = 0
        ) -> tuple[Optional[float], Optional[list[str]], Optional[list], int]:
            """
            Recursive search to find improved solutions.

            Tracks visited candidates to avoid cycles and applies a depth-dependent
            threshold to decide whether to explore a candidate further.

            Parameters
            ----------
            words : list[str]
                Current candidate solution.
            depth : int, optional
                Current recursion depth, by default 0.

            Returns
            -------
            tuple[Optional[float], Optional[list[str]], Optional[list], int]
                ``(perplexity, words, neighbor_types, max_depth)``. On success the
                first three describe a solution strictly better than
                ``perplexity_best`` and the chain of moves that led to it; when no
                improvement was found they are all ``None``. ``max_depth`` is the
                deepest recursion level reached below this call.
            """
            nonlocal iter_count
            visited.add(tuple(words))

            # customized thresholds for different samples; (refer to word-id mapping)
            # Each table maps recursion depth -> multiplier on perplexity_best. The multiplier
            # reaches 1.0 at depth 20, so nothing below that depth can pass the threshold test.
            if n_idx == 0:
                depth_to_threshold = {
                    0: 1.2, 1: 1.12, 2: 1.08, 3: 1.06, 4: 1.04, 5: 1.03, 6: 1.025, 7: 1.02, 8: 1.015, 9: 1.01, 10: 1.01,
                    11: 1.01, 12: 1.005, 13: 1.005, 14: 1.002, 15: 1.002, 16: 1.002, 17: 1.001, 18: 1.001, 19: 1.001, 20: 1.0,
                }
            elif n_idx in [1, 2]:
                depth_to_threshold = {
                    0: 1.2, 1: 1.12, 2: 1.08, 3: 1.06, 4: 1.04, 5: 1.03, 6: 1.025, 7: 1.02, 8: 1.015, 9: 1.01, 10: 1.01,
                    11: 1.01, 12: 1.005, 13: 1.005, 14: 1.002, 15: 1.002, 16: 1.002, 17: 1.001, 18: 1.001, 19: 1.001, 20: 1.0,
                }
            elif n_idx == 3:
                depth_to_threshold = {
                    0: 1.1, 1: 1.06, 2: 1.04, 3: 1.03, 4: 1.02, 5: 1.015, 6: 1.01, 7: 1.008, 8: 1.006, 9: 1.005, 10: 1.004,
                    11: 1.003, 12: 1.003, 13: 1.002, 14: 1.002, 15: 1.001, 16: 1.001, 17: 1.001, 18: 1.001, 19: 1.001, 20: 1.0,
                }
            elif n_idx == 4:
                depth_to_threshold = {
                    0: 1.05, 1: 1.03, 2: 1.02, 3: 1.015, 4: 1.01, 5: 1.008, 6: 1.006, 7: 1.004, 8: 1.003, 9: 1.002, 10: 1.002,
                    11: 1.002, 12: 1.002, 13: 1.002, 14: 1.001, 15: 1.001, 16: 1.001, 17: 1.001, 18: 1.001, 19: 1.001, 20: 1.0,
                }
            elif n_idx == 5:
                depth_to_threshold = {
                    0: 1.015, 1: 1.01, 2: 1.007, 3: 1.005, 4: 1.004, 5: 1.0035, 6: 1.003, 7: 1.0025, 8: 1.002, 9: 1.0015, 10: 1.001,
                    11: 1.001, 12: 1.001, 13: 1.001, 14: 1.001, 15: 1.001, 16: 1.001, 17: 1.001, 18: 1.001, 19: 1.001, 20: 1.0,
                }
            else:
                raise ValueError(f"Invalid n_idx: {n_idx}")

            # Generate neighbor permutations to evaluate
            neighbors = make_neighbors(words)
            max_depth = depth

            # Keep drawing batches from the (lazy) neighbour generator until it runs dry,
            # an improvement is found, or the iteration budget is spent.
            for _ in itertools.count(0):
                list_words_nxt: list[list[str]] = []
                list_texts_nxt: list[str] = []
                list_neighbor_type: list = []

                # Allow more candidate for more depth
                num_candidates = 2048 if depth < 2 else 4096 if depth < 5 else 8192
                while len(list_words_nxt) < num_candidates:
                    try:
                        words_nxt, neighbor_type = next(neighbors)
                        if tuple(words_nxt) in visited:
                            continue
                        list_words_nxt.append(words_nxt)
                        list_texts_nxt.append(" ".join(words_nxt))
                        list_neighbor_type.append(neighbor_type)
                    except StopIteration:
                        break
                # Too few fresh neighbours left for a full batch: give up on this node.
                if len(list_words_nxt) < min(num_candidates, int(1.5 * len(words) ** 2)):
                    # return None score improvements, None solution found, None perm meta, max_depth
                    return None, None, None, max_depth

                # Prune and select the top 112 candidates by estimated score plus 16 random ones.
                # Positions 0..111 of the kept batch are therefore ordered by estimated score.
                estimated_scores = score_estimator.estimate_scores(list_texts_nxt)
                indices_sorted = np.argsort(estimated_scores).tolist()
                indices_keep = indices_sorted[:112] + random.sample(indices_sorted[112:], 16)
                assert len(indices_keep) == batch_size

                list_words_nxt = [list_words_nxt[i] for i in indices_keep]
                list_texts_nxt = [list_texts_nxt[i] for i in indices_keep]
                list_neighbor_type = [list_neighbor_type[i] for i in indices_keep]

                # Scores candidates
                # NOTE(review): the "_with_error" suffix suggests batched scores are approximate — confirm in _get_score.
                list_perplexity_nxt_with_error = self._calc_perplexity(n_idx, list_texts_nxt)

                # pass to nn for online training
                score_estimator.update_parameters(list_texts_nxt, list_perplexity_nxt_with_error)

                # Despite the name, this is the batch position of the best *scored* candidate.
                estimated_rank = int(np.argmin(list_perplexity_nxt_with_error))
                words_nxt = list_words_nxt[estimated_rank]
                perplexity_nxt_with_error = list_perplexity_nxt_with_error[estimated_rank]
                neighbor_type = list_neighbor_type[estimated_rank]
                    
                # Re-score the batch winner on its own when it is within 2.0 of the best.
                if perplexity_nxt_with_error < perplexity_best + 2.0:
                    perplexity_nxt = self._calc_perplexity(n_idx, " ".join(words_nxt))
                else:
                    perplexity_nxt = perplexity_nxt_with_error

                iter_count += 1
                # If a candidate is strictly better, return it immediately
                if perplexity_nxt < perplexity_best:
                    stats.accepted[estimated_rank] += 1
                    return perplexity_nxt, words_nxt, [neighbor_type], max_depth
                
                # else, if < customized threshold, search deeper recursively
                elif perplexity_nxt < perplexity_best * depth_to_threshold[depth]:
                    # Visit the batch in random order rather than best-first.
                    search_order = list(range(batch_size))
                    random.shuffle(search_order)
                    for estimated_rank in search_order:
                        words_nxt = list_words_nxt[estimated_rank]
                        perplexity_nxt = list_perplexity_nxt_with_error[estimated_rank]
                        neighbor_type = list_neighbor_type[estimated_rank]

                        # track accepted & rejected
                        if (perplexity_nxt >= perplexity_best * depth_to_threshold[depth]):
                            stats.rejected[estimated_rank] += 1
                            continue
                        if tuple(words_nxt) in visited:
                            continue

                        stats.accepted[estimated_rank] += 1
                        perplexity_nxt, words_nxt, neighbor_types, max_depth_ = search(words_nxt, depth + 1)
                        max_depth = max(max_depth, max_depth_)

                        # A non-None result from below is a strict improvement: propagate it up,
                        # prepending the move taken at this level.
                        if perplexity_nxt is not None:
                            assert perplexity_nxt < perplexity_best
                            return (
                                perplexity_nxt,
                                words_nxt,
                                [neighbor_type] + neighbor_types,
                                max_depth,
                            )
                        
                # Iteration budget (shared across all recursion levels) is exhausted.
                if iter_count >= iter_total:
                    # return None score improvements, None solution found, None perm meta, max_depth
                    return None, None, None, max_depth
                
                # Periodic status print; perplexity_nxt may be None after a failed recursion.
                if iter_count % print_every == 0:
                    print(
                        f"[Search] iteration:{iter_count} best:{perplexity_best:.6f}"
                        f" current:{perplexity_nxt or math.inf:.2f}"
                        f" neighbor:{neighbor_type}"
                        f" depth:{depth}"
                        f" {stats.summary()}"
                    )

                # if pbar is not None:
                #     pbar.update(1)
                
        perplexity_nxt, words_nxt, neighbor_types, max_depth = search(words_best)
        if perplexity_nxt is not None:
            assert perplexity_nxt < perplexity_best
            print(
                f"[Search] Update: {perplexity_best:.6f}"
                f" -> {perplexity_nxt:.2f},"
                f" neighbor:{','.join(map(str, neighbor_types))}"
                f" max_depth:{max_depth}"
                f" {stats.summary()}"
            )
            perplexity_best = perplexity_nxt
            words_best = words_nxt
        else:
            print(f"[Search] No update, max_depth:{max_depth} {stats.summary()}")

        return words_best, perplexity_best


    def ILS_kick(self, n_idx: int, words: list[str], n_kick: int = 2) -> tuple[list[str], list[int]]:
        """
        Apply kick to perturb the current solution, hopefully escaping local minimas in subsequent searches.

        This function first performs a structured block removal and reinsertion (if n_kick==2),
        then applies a series of random swaps based on the problem index-specific strength.
        This perturbation is used when hill climbing stagnates.

        Parameters
        ----------
        n_idx : int
            The problem index.
        words : list[str]
            The current candidate word list.
        n_kick : int, optional
            The intensity of the kick, by default 2.

        Returns
        -------
        tuple[list[str], list[int]]
            The perturbed word list and a list of neighbor type information indicating the moves.
        """
        words = words.copy()
        neighbor_types = []

        if n_kick == 2:
            length = 10
            left = random.randint(0, len(words) - length)
            right = left + length
            removed = words[left:right]
            words = words[:left] + words[right:]
            neighbor_type = [left]
            for word in removed:
                insert_idx = random.randint(0, len(words))
                words.insert(insert_idx, word)
                neighbor_type.append(insert_idx)
            neighbor_types.append(tuple(neighbor_type))

        # different kicking intensity for different samples
        strength = [2, 3, 3, 4, 5, 10] 
        for _ in range(n_kick * strength[n_idx]):
            r0 = random.randint(0, len(words) - 1)
            r1 = random.randint(0, len(words) - 1)
            words[r0], words[r1] = words[r1], words[r0]
            neighbor_types.append((r0, r1))
        return words, neighbor_types


    def run(self, list_idx_target: Optional[list[int]] = None, print_every: int = 100):
        """
        Drive the search: hill-climb each targeted sample in turn, kicking on stagnation.

        The target indices are visited round-robin. The loop has no exit
        condition, so with a non-empty target list it runs until interrupted.
        On each visit the method:

        1. hill-climbs from the sample's current solution for 500 iterations;
        2. if the perplexity came back unchanged, perturbs the solution with
           ``ILS_kick`` and re-scores it;
        3. stores the result as the sample's current solution and passes it to
           ``_update_best_all``;
        4. calls ``save_text`` when no kick was applied and the perplexity is
           below 1.1x the all-time best;
        5. once more than 1800 seconds have passed since the last save, saves
           the score memos and the current sample's estimator.

        Parameters
        ----------
        list_idx_target : Optional[list[int]], optional
            Problem indices to cycle over; ``None`` selects all ``NUM_SAMPLES`` problems.
        print_every : int, optional
            Status banners are printed on every ``print_every``-th pass of the
            outer loop; the value is also forwarded to ``_hillclimbing``.
            By default 100.
        """
        targets = list(range(NUM_SAMPLES)) if list_idx_target is None else list_idx_target

        for pass_no, n_idx in enumerate(itertools.cycle(targets)):
            free_memory()
            words_best, perplexity_start = self._get_best(n_idx)

            # Banners are shown only on every `print_every`-th pass.
            show_banner = pass_no % print_every == 0
            rule = "-" * 100
            if show_banner:
                print(rule)
                print(f"[Step] sample_id: {n_idx} Prev: {perplexity_start:.6f}")

            words_best, perplexity_best = self._hillclimbing(
                n_idx,
                words_best,
                perplexity_start,
                score_estimator=self.score_estimators[n_idx],
                iter_total=500,
                print_every=print_every,
            )

            if show_banner:
                print(f"    - [Step] sample_id: {n_idx} Current: {perplexity_best:.6f}")
                print(rule)

            # An unchanged score means hill climbing stalled: perturb the solution.
            did_kick = bool(perplexity_start == perplexity_best)
            if did_kick:
                # Already sitting on the all-time best: zero the counter so the
                # decrement below forces a fresh kick budget.
                if words_best == self._get_best_all(n_idx)[0]:
                    self.list_num_kick[n_idx] = 0

                # Count down the kick budget; once it is used up, draw a new
                # budget of 2 or 3 and restart from the all-time best words.
                self.list_num_kick[n_idx] -= 1
                budget_exhausted = self.list_num_kick[n_idx] <= 0
                if budget_exhausted:
                    self.list_num_kick[n_idx] = random.randint(2, 3)
                    print("[Step] Reset words")
                    words_best = self._get_best_all(n_idx)[0]
                n_kick = self.list_num_kick[n_idx]

                words_best, neighbor_types = self.ILS_kick(n_idx, words_best, n_kick=n_kick)
                print(f"[Step] Apply {n_kick} kicks: {neighbor_types}")
                perplexity_best = self._calc_perplexity(n_idx, " ".join(words_best))

            # Record the (possibly kicked) solution as the current one for this sample.
            self.list_words_best[n_idx] = words_best
            self.list_perplexity_best[n_idx] = perplexity_best
            self._update_best_all(n_idx, words_best, perplexity_best)

            # Only un-kicked solutions scoring below 1.1x the all-time best are written out.
            if not did_kick:
                near_best = perplexity_best < self._get_best_all(n_idx)[1] * 1.1
                if near_best:
                    save_text(self._calc_perplexity, n_idx, " ".join(words_best), verbose=1)

            # Checkpoint the score memos and this sample's estimator every 1800 s.
            if time() > self.last_time_score_memo_saved + 1800:
                save_score_memo(self.score_memo, self.score_memo_with_error)
                self.last_time_score_memo_saved = time()
                self.score_estimators[n_idx].save_model()
                

## Run

In [None]:
# Build the optimizer instance whose `run` loop is started in the next cell.
# NOTE(review): the exact effect of `flag_use_best` / `flag_shuffle` is defined in
# `Optimization.__init__` (not shown here) — confirm before changing them.
optimizer = Optimization(flag_use_best=False, flag_shuffle=True)  # do not use flag_use_best on 1st run

In [None]:
# Run the main loop on the samples listed in `list_idx_target` (here only sample 0).
# Omit the argument entirely to cycle over all samples (see the commented call below);
# passing an empty list would make `run` return immediately.
# The loop cycles over its targets indefinitely, so interrupt the cell to stop it.
optimizer.run(list_idx_target=[0])
# optimizer.run()

----------------------------------------------------------------------------------------------------
[Step] sample_id: 0 Prev: 2339.780733
[Search] Update: 2339.780733 -> 1109.56, neighbor:(0, 7, 8, 0) max_depth:0  accepted:[0, 0, 1, 0, 0, 0, 0, 0], rejected:[0, 0, 0, 0, 0, 0, 0, 0]     - total:1
    - [Step] sample_id: 0 Current: 1109.559923
----------------------------------------------------------------------------------------------------
score:1109.5599
[Search] Update: 1109.559923 -> 850.73, neighbor:(3, 5, 7, 0) max_depth:0  accepted:[1, 0, 0, 0, 0, 0, 0, 0], rejected:[0, 0, 0, 0, 0, 0, 0, 0]     - total:1
score:850.7291
[Search] Update: 850.729103 -> 847.41, neighbor:(4, 8, 9, 0) max_depth:0  accepted:[0, 0, 1, 0, 0, 0, 0, 0], rejected:[0, 0, 0, 0, 0, 0, 0, 0]     - total:1
score:847.4124
[Search] Update: 847.412425 -> 780.67, neighbor:(4, 9, 10, 0) max_depth:0  accepted:[1, 0, 0, 0, 0, 0, 0, 0], rejected:[0, 0, 0, 0, 0, 0, 0, 0]     - total:1
score:780.6729
[Search] Update: 780