# Santa 2024 - The Perplexity Permutation Puzzle

Minimizing the perplexity of given string

In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import pandas as pd
import re
import math
import numpy as np
import random
from tqdm import tqdm
from heapq import heappush, heappop

random.seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [2]:
# damn Llama cannot do the rearrange job
# model_name = "/kaggle/input/llama-3.2/transformers/3b-instruct/1"
# tokenizer = AutoTokenizer.from_pretrained(model_name)
# model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

In [3]:
"""Evaluation metric for Santa 2024."""
# https://www.kaggle.com/code/metric/santa-2024-metric/notebook
import gc
import os
from math import exp
from collections import Counter
from typing import List, Optional, Union

import numpy as np
import pandas as pd
import transformers
import torch

os.environ['OMP_NUM_THREADS'] = '1'
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
PAD_TOKEN_LABEL_ID = torch.nn.CrossEntropyLoss().ignore_index
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class ParticipantVisibleError(Exception):
    pass


def score(
    solution: pd.DataFrame,
    submission: pd.DataFrame,
    row_id_column_name: str,
    model_path: str = '/kaggle/input/gemma-2/transformers/gemma-2-9b/2',
    load_in_8bit: bool = False,
    clear_mem: bool = False,
) -> float:
    """
    Calculates the mean perplexity of submitted text permutations compared to an original text.

    Parameters
    ----------
    solution : DataFrame
        DataFrame containing the original text in a column named 'text'.
        Includes a row ID column specified by `row_id_column_name`.

    submission : DataFrame
        DataFrame containing the permuted text in a column named 'text'.
        Must have the same row IDs as the solution.
        Includes a row ID column specified by `row_id_column_name`.

    row_id_column_name : str
        Name of the column containing row IDs.
        Ensures aligned comparison between solution and submission.

    model_path : str, default='/kaggle/input/gemma-2/transformers/gemma-2-9b/2'
        Path to the serialized LLM.

    load_in_8bit : bool, default=False
        Use 8-bit quantization for the model. Requires CUDA.

    clear_mem : bool, default=False
        Clear GPU memory after scoring by clearing the CUDA cache.
        Useful for testing.

    Returns
    -------
    float
        The mean perplexity score. Lower is better.

    Raises
    ------
    ParticipantVisibleError
        If the submission format is invalid or submitted strings are not valid permutations.

    Examples
    --------
    >>> import pandas as pd
    >>> model_path = "/kaggle/input/gemma-2/transformers/gemma-2-9b/2"
    >>> solution = pd.DataFrame({
    ...     'id': [0, 1],
    ...     'text': ["this is a normal english sentence", "the quick brown fox jumps over the lazy dog"]
    ... })
    >>> submission = pd.DataFrame({
    ...     'id': [0, 1],
    ...     'text': ["sentence english normal a is this", "lazy the over jumps fox brown quick the dog"]
    ... })
    >>> score(solution, submission, 'id', model_path=model_path, clear_mem=True) > 0
    True
    """
    # Check that each submitted string is a permutation of the solution string
    sol_counts = solution.loc[:, 'text'].str.split().apply(Counter)
    sub_counts = submission.loc[:, 'text'].str.split().apply(Counter)
    invalid_mask = sol_counts != sub_counts
    if invalid_mask.any():
        raise ParticipantVisibleError(
            'At least one submitted string is not a valid permutation of the solution string.'
        )

    # Calculate perplexity for the submitted strings
    sub_strings = [
        ' '.join(s.split()) for s in submission['text'].tolist()
    ]  # Split and rejoin to normalize whitespace
    scorer = PerplexityCalculator(
        model_path=model_path,
        load_in_8bit=load_in_8bit,
    )  # Initialize the perplexity calculator with a pre-trained model
    perplexities = scorer.get_perplexity(
        sub_strings
    )  # Calculate perplexity for each submitted string

    if clear_mem:
        # Just move on if it fails. Not essential if we have the score.
        try:
            scorer.clear_gpu_memory()
        except:
            print('GPU memory clearing failed.')

    return float(np.mean(perplexities))


class PerplexityCalculator:
    """
    Calculates perplexity of text using a pre-trained language model.

    Adapted from https://github.com/asahi417/lmppl/blob/main/lmppl/ppl_recurrent_lm.py

    Parameters
    ----------
    model_path : str
        Path to the pre-trained language model

    load_in_8bit : bool, default=False
        Use 8-bit quantization for the model. Requires CUDA.

    device_map : str, default="auto"
        Device mapping for the model.
    """

    def __init__(
        self,
        model_path: str,
        load_in_8bit: bool = False,
        device_map: str = 'auto',
    ):
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
        # Configure model loading based on quantization setting and device availability
        if load_in_8bit:
            if DEVICE.type != 'cuda':
                raise ValueError('8-bit quantization requires CUDA device')
            quantization_config = transformers.BitsAndBytesConfig(load_in_8bit=True)
            self.model = transformers.AutoModelForCausalLM.from_pretrained(
                model_path,
                quantization_config=quantization_config,
                device_map=device_map,
            )
        else:
            self.model = transformers.AutoModelForCausalLM.from_pretrained(
                model_path,
                torch_dtype=torch.float16 if DEVICE.type == 'cuda' else torch.float32,
                device_map=device_map,
            )

        self.loss_fct = torch.nn.CrossEntropyLoss(reduction='none')

        self.model.eval()

    def get_perplexity(
        self, input_texts: Union[str, List[str]], debug=False
    ) -> Union[float, List[float]]:
        """
        Calculates the perplexity of given texts.

        Parameters
        ----------
        input_texts : str or list of str
            A single string or a list of strings.

        batch_size : int, default=None
            Batch size for processing. Defaults to the number of input texts.

        debug : bool, default=False
            Print debugging information.

        Returns
        -------
        float or list of float
            A single perplexity value if input is a single string,
            or a list of perplexity values if input is a list of strings.

        Examples
        --------
        >>> import pandas as pd
        >>> model_path = "/kaggle/input/gemma-2/transformers/gemma-2-9b/2"
        >>> scorer = PerplexityCalculator(model_path=model_path)

        >>> submission = pd.DataFrame({
        ...     'id': [0, 1, 2],
        ...     'text': ["this is a normal english sentence", "thsi is a slihgtly misspelled zr4g sentense", "the quick brown fox jumps over the lazy dog"]
        ... })
        >>> perplexities = scorer.get_perplexity(submission["text"].tolist())
        >>> perplexities[0] < perplexities[1]
        True
        >>> perplexities[2] < perplexities[0]
        True

        >>> perplexities = scorer.get_perplexity(["this is a sentence", "another sentence"])
        >>> all(p > 0 for p in perplexities)
        True

        >>> scorer.clear_gpu_memory()
        """
        single_input = isinstance(input_texts, str)
        input_texts = [input_texts] if single_input else input_texts

        loss_list = []
        with torch.no_grad():
            # Process each sequence independently
            for text in input_texts:
                # Explicitly add sequence boundary tokens to the text
                text_with_special = f"{self.tokenizer.bos_token}{text}{self.tokenizer.eos_token}"

                # Tokenize
                model_inputs = self.tokenizer(
                    text_with_special,
                    return_tensors='pt',
                    add_special_tokens=False,
                )

                if 'token_type_ids' in model_inputs:
                    model_inputs.pop('token_type_ids')

                model_inputs = {k: v.to(DEVICE) for k, v in model_inputs.items()}

                # Get model output
                output = self.model(**model_inputs, use_cache=False)
                logits = output['logits']

                # Shift logits and labels for calculating loss
                shift_logits = logits[..., :-1, :].contiguous()  # Drop last prediction
                shift_labels = model_inputs['input_ids'][..., 1:].contiguous()  # Drop first input

                # Calculate token-wise loss
                loss = self.loss_fct(
                    shift_logits.view(-1, shift_logits.size(-1)),
                    shift_labels.view(-1)
                )

                # Calculate average loss
                sequence_loss = loss.sum() / len(loss)
                loss_list.append(sequence_loss.cpu().item())

                # Debug output
                if debug:
                    print(f"\nProcessing: '{text}'")
                    print(f"With special tokens: '{text_with_special}'")
                    print(f"Input tokens: {model_inputs['input_ids'][0].tolist()}")
                    print(f"Target tokens: {shift_labels[0].tolist()}")
                    print(f"Input decoded: {self.tokenizer.decode(model_inputs['input_ids'][0])}")
                    print(f"Target decoded: {self.tokenizer.decode(shift_labels[0])}")
                    print(f"Individual losses: {loss.tolist()}")
                    print(f"Average loss: {sequence_loss.item():.4f}")

        ppl = [exp(i) for i in loss_list]

        if debug:
            print("\nFinal perplexities:")
            for text, perp in zip(input_texts, ppl):
                print(f"Text: '{text}'")
                print(f"Perplexity: {perp:.2f}")

        return ppl[0] if single_input else ppl

    def clear_gpu_memory(self) -> None:
        """Clears GPU memory by deleting references and emptying caches."""
        if not torch.cuda.is_available():
            return

        # Delete model and tokenizer if they exist
        if hasattr(self, 'model'):
            del self.model
        if hasattr(self, 'tokenizer'):
            del self.tokenizer

        # Run garbage collection
        gc.collect()

        # Clear CUDA cache and reset memory stats
        with DEVICE:
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
            torch.cuda.reset_peak_memory_stats()

In [4]:
scorer = PerplexityCalculator("/kaggle/input/gemma-2/transformers/gemma-2-9b/2")

Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

In [5]:
import random
import math


def get_neighbor(sequence, vocabulary):
    new_sequence = sequence[:]
    idx1, idx2 = random.sample(range(len(sequence)), 2)
    new_sequence[idx1], new_sequence[idx2] = new_sequence[idx2], new_sequence[idx1]
    return new_sequence

def simulated_annealing(vocabulary, sequence_length, initial_temperature, cooling_rate, num_iterations):
    # random.shuffle(vocabulary)
    current_sequence = vocabulary
    current_perplexity = scorer.get_perplexity(" ".join(current_sequence))
    best_sequence = current_sequence
    best_perplexity = current_perplexity
    
    temperature = initial_temperature
    
    for iteration in range(num_iterations):
        neighbor_sequence = get_neighbor(current_sequence, vocabulary)
        neighbor_perplexity = scorer.get_perplexity(" ".join(neighbor_sequence))
        
        if (neighbor_perplexity < current_perplexity or
            math.exp(-(neighbor_perplexity - current_perplexity) / temperature)*10 > random.random()):
            current_sequence = neighbor_sequence
            current_perplexity = neighbor_perplexity
        
        if neighbor_perplexity < best_perplexity:
            best_sequence = neighbor_sequence
            best_perplexity = neighbor_perplexity
        
        temperature *= cooling_rate
        if iteration%500 == 0:
            print(f"Iteration {iteration}: Best Perplexity {best_perplexity}, Current Perplexity {current_perplexity}, Temperature {temperature}")

    return best_sequence, best_perplexity

# string = 'sleigh of the magi yuletide cheer is unwrap gifts and eat cheer holiday decorations holly jingle relax sing carol visit workshop grinch naughty nice chimney stocking ornament nutcracker polar beard'
# sequence_length = len(string.split(" "))
# initial_temperature = 100.0
# cooling_rate = 0.97
# num_iterations = 3000
# vocabulary = string.split(" ")
# for i in range(3):
#     best_sequence, best_perplexity = simulated_annealing(vocabulary, sequence_length, initial_temperature, cooling_rate, num_iterations)
#     print(f"Best sequence: {' '.join(best_sequence)} with perplexity {best_perplexity}")

In [6]:
# 'sleigh of the magi yuletide cheer is unwrap gifts and eat cheer holiday decorations holly jingle relax sing carol visit workshop grinch naughty nice chimney stocking ornament nutcracker polar beard', 
# word_list= ['sleigh of the magi',  'yuletide','cheer', 'is','unwrap gifts','and','eat',
#             'cheer holiday decorations', 'holly jingle', 'relax', 'sing carol','visit workshop',
#             'grinch','naughty nice chimney','stocking ornament','nutcracker','polar beard']
# maxn=10000
# ans=''
# for i in tqdm(range(10000)):
#     random.shuffle(word_list)
#     score=scorer.get_perplexity(" ".join(word_list))
#     if score<maxn:
#         maxn=score
#         ans=' '.join(word_list)
#     if(i%500==0):
#         print(f"{maxn},{ans}")
# print(maxn)
# print(ans)

In [7]:
def heuristic_rearrange(words, scorer, beam_size=5):
    sequences = [([], words)]
    cache = {}

    for _ in range(len(words)):
        all_candidates = []
        for seq, remaining in sequences:
            candidates = []
            candidate_texts = []
            candidate_data = []

            for i, word in enumerate(remaining):
                new_seq = seq + [word]
                new_remaining = remaining[:i] + remaining[i+1:]
                seq_key = ' '.join(new_seq)

                if seq_key in cache:
                    perplexity = cache[seq_key]
                    all_candidates.append((perplexity, new_seq, new_remaining))
                else:
                    candidate_texts.append(seq_key)
                    candidate_data.append((new_seq, new_remaining, seq_key))

            if candidate_texts:
                perplexities = scorer.get_perplexity(candidate_texts)
                perplexities = list(map(int, perplexities))
                for (candidate_seq, candidate_remaining, seq_key), perplexity in zip(candidate_data, perplexities):
                    cache[seq_key] = perplexity
                    all_candidates.append((perplexity, candidate_seq, candidate_remaining))

        ordered = sorted(all_candidates, key=lambda x: x[0])
        sequences = [(seq, rem) for _, seq, rem in ordered[:beam_size]]

    best_seq = sequences[0][0]
    return ' '.join(best_seq)

In [8]:
# This is the heuristic_rearrange 5 solution
math = [
    'reindeer mistletoe scrooge gingerbread elf fireplace chimney family advent ornament',
    'reindeer mistletoe scrooge and the gingerbread family fireplace elf night walk advent ornament chimney bake sleep laugh jump drive give',
    'jingle yuletide carol grinch nutcracker holiday decorations ornament stocking gifts naughty nice holly cheer sleigh beard chimney workshop magi polar',
    'yuletide cheer cheer and sing of the carol grinch nutcracker holiday decorations gifts stocking unwrap ornament jingle sleigh holly nice naughty chimney visit beard workshop polar relax eat is magi',
    'eggnog cookie poinsettia fruitcake chocolate peppermint candy snowglobe wreath and star candle angel card paper doll game night in with the hohoho season of joy greeting from you to we hope that wish it merry have peace wonder believe dream not as wrapping bow toy workshop fireplace milk kaggle puzzle',
    'eggnog yuletide scrooge mistletoe nutcracker poinsettia gingerbread cookie fruitcake holly wreath holiday ornament snowglobe candle fireplace stocking fireplace chimney reindeer elf toy sleigh gifts candy peppermint chocolate ornament decorations advent season family merry and grinch joy peace cheer cheer and hope greeting card and wrapping paper wish you the wonder of the night star night in the dream of it is nice to believe that we have not sleep with angel visit carol sing jingle unwrap give eat bake laugh walk drive jump hohoho from chimney doll workshop workshop puzzle game naughty milk beard polar bow relax as magi kaggle']

In [9]:
current_best = [
    'reindeer mistletoe elf gingerbread family advent scrooge chimney fireplace ornament',
    'reindeer sleep walk the night and drive mistletoe scrooge laugh chimney jump elf bake gingerbread family give advent fireplace ornament', 
    'sleigh yuletide beard carol cheer chimney decorations gifts grinch holiday holly jingle magi naughty nice nutcracker ornament polar workshop stocking', 
    'sleigh of the magi yuletide cheer is unwrap gifts and eat cheer holiday decorations holly jingle relax sing carol visit workshop grinch naughty nice chimney stocking ornament nutcracker polar beard', 
    'from and of to the as in that it we with not you have milk chocolate candy peppermint eggnog cookie fruitcake toy doll game puzzle greeting card wrapping paper bow candle fireplace wreath poinsettia snowglobe angel star wish dream night season wonder believe hope joy peace merry hohoho kaggle workshop',
    'from and and as we and have the in is it of not that the to with you advent card angel bake beard believe bow candy candle carol cheer cheer chocolate chimney cookie decorations doll dream drive eat eggnog family fireplace fireplace chimney fruitcake game gifts give gingerbread greeting grinch holiday holly hohoho hope jingle jump joy kaggle laugh magi merry milk mistletoe naughty nice night night elf nutcracker ornament ornament of the wrapping paper peace peppermint polar poinsettia puzzle reindeer relax scrooge season sing sleigh sleep snowglobe star stocking toy unwrap visit walk wish wonder workshop workshop wreath yuletide']
for row in current_best:
    print(scorer.get_perplexity(row))

468.96133548013836
421.72883862887016
297.4792420976342
200.61770629462436
71.09958344833998
34.59608450070472


In [10]:
def rearrange_words(words,best_score, max_iterations=20000):
    """What I can do, let's go random"""
    min_perplexity = float('inf')
    best_sentence = ''
    tried_permutations = set()
    words = words.split(" ")
    # print(len(words))
    for _ in range(max_iterations):
        # Generate a random permutation
        permuted_words = tuple(torch.randperm(len(words)).tolist())
        if permuted_words in tried_permutations:
            continue
        tried_permutations.add(permuted_words)

        sentence = ' '.join([words[i] for i in permuted_words])
        perplexity = scorer.get_perplexity(sentence)
        if perplexity < best_score:
            print(sentence)
            print(perplexity)
        if perplexity < min_perplexity:
            min_perplexity = perplexity
            best_sentence = sentence

    return best_sentence, min_perplexity

In [11]:
submission = pd.read_csv('/kaggle/input/santa-2024/sample_submission.csv')

results = {'id': [], 'text': []}

for idx, row in tqdm(submission.iterrows(), total=submission.shape[0]):
    text_id = row['id']
    scrambled_text = row['text']
    # if text_id < 3:
    #     continue
    rearranged_text_math = heuristic_rearrange(scrambled_text.split(" "), scorer, 1)
    rearranged_text_best = current_best[idx]
    # best_score = scorer.get_perplexity(rearranged_text_best)

    print(rearranged_text_math)
    best_score = scorer.get_perplexity(rearranged_text_best)
    # rearranged_text_random, random_score = rearrange_words(scrambled_text, best_score)
    # print(rearranged_text_random)
    min_score = float('inf')
    math_score = scorer.get_perplexity(rearranged_text_math)
    best_score = scorer.get_perplexity(rearranged_text_best)
    if math_score < min_score:
        min_score = math_score
        rearranged_text = rearranged_text_math
    # if random_score < min_score:
    #     min_score = random_score
    #     rearranged_text = rearranged_text_random
    if best_score < min_score:
        min_score = best_score
        rearranged_text = rearranged_text_best # I made a mistake here :(
    print(math_score)
    print(best_score)
    # print(random_score)
    results['id'].append(text_id)
    results['text'].append(rearranged_text)

  0%|          | 0/6 [00:00<?, ?it/s]

scrooge mistletoe ornament family advent fireplace chimney elf reindeer gingerbread


 17%|█▋        | 1/6 [00:06<00:30,  6.06s/it]

1339.011593645358
468.96133548013836
scrooge mistletoe ornament and reindeer family advent fireplace chimney elf night sleep the gingerbread bake walk drive give laugh jump


 33%|███▎      | 2/6 [00:28<01:03, 15.82s/it]

1466.7430622322397
421.72883862887016
yuletide gifts grinch ornament nutcracker decorations holiday stocking holly jingle sleigh carol cheer chimney naughty nice beard workshop polar magi


 50%|█████     | 3/6 [00:51<00:57, 19.19s/it]

962.1704775691508
297.4792420976342
yuletide gifts unwrap holiday cheer the nutcracker and grinch decorations ornament stocking holly jingle sleigh carol sing cheer of chimney visit naughty nice beard eat relax polar workshop is magi


 67%|██████▋   | 4/6 [01:44<01:04, 32.33s/it]

761.7996057572035
200.61770629462436
eggnog fruitcake poinsettia snowglobe wreath candle cookie star candy peppermint chocolate milk hohoho merry and joy peace hope angel wish dream believe wonder night season greeting card paper wrapping bow toy doll game puzzle fireplace to the from of you we with in it that as not have workshop kaggle


 83%|████████▎ | 5/6 [04:12<01:14, 74.11s/it]

282.7074217775683
71.09958344833998
eggnog yuletide poinsettia mistletoe fruitcake scrooge nutcracker snowglobe gingerbread cookie holly wreath ornament reindeer sleigh elf stocking candy peppermint holiday ornament star angel carol candle advent fireplace fireplace chimney chimney decorations gifts wrapping paper bow card greeting merry hohoho joy peace cheer cheer season wish hope dream believe family and you and and to from with in the the the of of is that it not have as we give eat sleep sing laugh walk jump drive visit bake unwrap jingle relax wonder night night nice naughty grinch toy doll game puzzle chocolate milk beard polar workshop workshop kaggle magi


100%|██████████| 6/6 [14:45<00:00, 147.53s/it]

109.24447458721197
34.59608450070472





In [12]:
# save to submission.csv
output_df = pd.DataFrame(results)
output_df.head()
output_df.to_csv('submission.csv', index=False)