In [3]:
"""Evaluation metric for Santa 2024."""

import gc
import os
from math import exp
from collections import Counter
from typing import List, Optional, Union

import numpy as np
import pandas as pd
import transformers
import torch

os.environ['OMP_NUM_THREADS'] = '1'
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
PAD_TOKEN_LABEL_ID = torch.nn.CrossEntropyLoss().ignore_index
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class ParticipantVisibleError(Exception):
    pass


def score(
    solution: pd.DataFrame,
    submission: pd.DataFrame,
    row_id_column_name: str,
    model_path: str = '/kaggle/input/gemma-2/transformers/gemma-2-9b/2',
    load_in_8bit: bool = False,
    clear_mem: bool = False,
) -> float:
    """
    Calculates the mean perplexity of submitted text permutations compared to an original text.

    Parameters
    ----------
    solution : DataFrame
        DataFrame containing the original text in a column named 'text'.
        Includes a row ID column specified by `row_id_column_name`.

    submission : DataFrame
        DataFrame containing the permuted text in a column named 'text'.
        Must have the same row IDs as the solution.
        Includes a row ID column specified by `row_id_column_name`.

    row_id_column_name : str
        Name of the column containing row IDs.
        Ensures aligned comparison between solution and submission.

    model_path : str, default='/kaggle/input/gemma-2/transformers/gemma-2-9b/2'
        Path to the serialized LLM.

    load_in_8bit : bool, default=False
        Use 8-bit quantization for the model. Requires CUDA.

    clear_mem : bool, default=False
        Clear GPU memory after scoring by clearing the CUDA cache.
        Useful for testing.

    Returns
    -------
    float
        The mean perplexity score. Lower is better.

    Raises
    ------
    ParticipantVisibleError
        If the submission format is invalid or submitted strings are not valid permutations.
    """
    # Check that each submitted string is a permutation of the solution string
    sol_counts = solution.loc[:, 'text'].str.split().apply(Counter)
    sub_counts = submission.loc[:, 'text'].str.split().apply(Counter)
    invalid_mask = sol_counts != sub_counts
    if invalid_mask.any():
        raise ParticipantVisibleError(
            'At least one submitted string is not a valid permutation of the solution string.'
        )

    # Calculate perplexity for the submitted strings
    sub_strings = [
        ' '.join(s.split()) for s in submission['text'].tolist()
    ]  # Split and rejoin to normalize whitespace
    scorer = PerplexityCalculator(
        model_path=model_path,
        load_in_8bit=load_in_8bit,
    )  # Initialize the perplexity calculator with a pre-trained model
    perplexities = scorer.get_perplexity(
        sub_strings
    )  # Calculate perplexity for each submitted string

    if clear_mem:
        # Just move on if it fails. Not essential if we have the score.
        try:
            scorer.clear_gpu_memory()
        except:
            print('GPU memory clearing failed.')

    return float(np.mean(perplexities))


class PerplexityCalculator:
    """
    Calculates perplexity of text using a pre-trained language model.

    Adapted from https://github.com/asahi417/lmppl/blob/main/lmppl/ppl_recurrent_lm.py

    Parameters
    ----------
    model_path : str
        Path to the pre-trained language model

    load_in_8bit : bool, default=False
        Use 8-bit quantization for the model. Requires CUDA.

    device_map : str, default="auto"
        Device mapping for the model.
    """

    def __init__(
        self,
        model_path: str,
        load_in_8bit: bool = False,
        device_map: str = 'auto',
    ):
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
        # Configure model loading based on quantization setting and device availability
        if load_in_8bit:
            if DEVICE.type != 'cuda':
                raise ValueError('8-bit quantization requires CUDA device')
            quantization_config = transformers.BitsAndBytesConfig(load_in_8bit=True)
            self.model = transformers.AutoModelForCausalLM.from_pretrained(
                model_path,
                quantization_config=quantization_config,
                device_map=device_map,
            )
        else:
            self.model = transformers.AutoModelForCausalLM.from_pretrained(
                model_path,
                torch_dtype=torch.float16 if DEVICE.type == 'cuda' else torch.float32,
                device_map=device_map,
            )

        self.loss_fct = torch.nn.CrossEntropyLoss(reduction='none')

        self.model.eval()

    def get_tokenizer(self):
        return self.tokenizer



    def get_perplexity(
        self, input_texts: Union[str, List[str]], debug=False
    ) -> Union[float, List[float]]:
        """
        Calculates the perplexity of given texts.

        Parameters
        ----------
        input_texts : str or list of str
            A single string or a list of strings.

        batch_size : int, default=None
            Batch size for processing. Defaults to the number of input texts.

        debug : bool, default=False
            Print debugging information.

        Returns
        -------
        float or list of float
            A single perplexity value if input is a single string,
            or a list of perplexity values if input is a list of strings.
        """
        single_input = isinstance(input_texts, str)
        input_texts = [input_texts] if single_input else input_texts

        loss_list = []
        with torch.no_grad():
            # Process each sequence independently
            for text in input_texts:
                # Explicitly add sequence boundary tokens to the text
                text_with_special = f"{self.tokenizer.bos_token}{text}{self.tokenizer.eos_token}"

                # Tokenize
                model_inputs = self.tokenizer(
                    text_with_special,
                    return_tensors='pt',
                    add_special_tokens=False,
                )

                if 'token_type_ids' in model_inputs:
                    model_inputs.pop('token_type_ids')

                model_inputs = {k: v.to(DEVICE) for k, v in model_inputs.items()}

                # Get model output
                output = self.model(**model_inputs, use_cache=False)
                logits = output['logits']

                # Shift logits and labels for calculating loss
                shift_logits = logits[..., :-1, :].contiguous()  # Drop last prediction
                shift_labels = model_inputs['input_ids'][..., 1:].contiguous()  # Drop first input

                # Calculate token-wise loss
                loss = self.loss_fct(
                    shift_logits.view(-1, shift_logits.size(-1)),
                    shift_labels.view(-1)
                )

                # Calculate average loss
                sequence_loss = loss.sum() / len(loss)
                loss_list.append(sequence_loss.cpu().item())

                # Debug output
                if debug:
                    print(f"\nProcessing: '{text}'")
                    print(f"With special tokens: '{text_with_special}'")
                    print(f"Input tokens: {model_inputs['input_ids'][0].tolist()}")
                    print(f"Target tokens: {shift_labels[0].tolist()}")
                    print(f"Input decoded: {self.tokenizer.decode(model_inputs['input_ids'][0])}")
                    print(f"Target decoded: {self.tokenizer.decode(shift_labels[0])}")
                    print(f"Individual losses: {loss.tolist()}")
                    print(f"Average loss: {sequence_loss.item():.4f}")

        ppl = [exp(i) for i in loss_list]

        if debug:
            print("\nFinal perplexities:")
            for text, perp in zip(input_texts, ppl):
                print(f"Text: '{text}'")
                print(f"Perplexity: {perp:.2f}")

        return ppl[0] if single_input else ppl

    def clear_gpu_memory(self) -> None:
        """Clears GPU memory by deleting references and emptying caches."""
        if not torch.cuda.is_available():
            return

        # Delete model and tokenizer if they exist
        if hasattr(self, 'model'):
            del self.model
        if hasattr(self, 'tokenizer'):
            del self.tokenizer

        # Run garbage collection
        gc.collect()

        # Clear CUDA cache and reset memory stats
        with DEVICE:
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
            torch.cuda.reset_peak_memory_stats()



In [4]:
import pandas as pd
model_path = "./model/gemma"
scorer = PerplexityCalculator(model_path=model_path, load_in_8bit=False)

Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

Some parameters are on the meta device because they were offloaded to the cpu.


## Example 
``` py
submission = pd.DataFrame({
'id': [0, 1, 2],
'text': ["this is a normal english sentence", "thsi is a slihgtly misspelled zr4g sentense", "the quick brown fox jumps over the lazy dog"]
})
perplexities = scorer.get_perplexity(submission["text"].tolist(), debug=True)

sol_counts = submission.loc[:, 'text'].str.split().apply(Counter)
sol_counts


```

In [5]:
#sentence = "thsi is a slihgtly misspelled zr4g sentense"
sentence = "advent chimney elf family fireplace gingerbread mistletoe ornament reindeer scrooge"
#sentence = "A christmas sentence will appear after the dot. ornament elf gingerbread"
#sentence = "This is a sentence: I am from Peru"
#sentence = "advent chimney elf family fireplace gingerbread mistletoe ornament reindeer scrooge walk give jump drive bake the sleep night laugh and yuletide decorations gifts cheer holiday carol magi nutcracker polar grinch sleigh chimney workshop stocking ornament holly jingle beard naughty nice sing cheer and of the is eat visit relax unwrap hohoho candle poinsettia snowglobe peppermint eggnog fruitcake chocolate candy puzzle game doll toy workshop wonder believe dream hope peace joy merry season greeting card wrapping paper bow fireplace night cookie milk star wish wreath angel the to of and in that have it not with as you from we kaggle"
#sentence = "advent chimney elf family fireplace gingerbread mistletoe ornament reindeer scrooge walk give jump drive bake the sleep night laugh and"
encoded = scorer.tokenizer.encode(sentence, add_special_tokens=False)
tokens = scorer.tokenizer.convert_ids_to_tokens(encoded)

In [6]:
data_input = pd.read_csv("./Input/sample_submission.csv")
Iterative_sentences = data_input['text'].tolist()

In [7]:
scorer.tokenizer.decode(136507)

' gingerbread'

# Actual attempt

In [8]:
import torch.nn.functional as F
import math

def get_dictionaries(tokens, encoded):
    dictionary = {}
    init_tokens = {}
    words = []
    current_word = ""
    idx = 0
    corr_enc = []
    for i, token in enumerate(tokens):
        if token.startswith('▁') or idx==0:
            idx += 1

            if current_word:
                words.append(current_word)
                dictionary[current_word] = corr_enc[len(words) - 1]
            # Remove the '▁' and start a new word
            if idx != 1:
                current_word = token[1:]
            else: 
                current_word = token[0:]
            corr_enc.append([encoded[i]])
            init_tokens[encoded[i]] = current_word
        else:
            # Continue the existing word
            corr_enc[len(words)].append(encoded[i])
            current_word += token

    # After the loop ends, add the last word if it exists
    if current_word:
        words.append(current_word)
        dictionary[current_word] = corr_enc[len(words) - 1]
    return dictionary, init_tokens, words

In [9]:
def get_tuples(dictionary):
    available_words = list(dictionary.keys())

    ppx_per_tuple = {}

    for word in available_words:
        remaining_words = [w for w in available_words if w != word]
        ppx_per_tuple[word] = {}
        for w in remaining_words:
            tuple_score = scorer.get_perplexity(f"A christmas sentence will appear after the dot. {word} {w}")
            ppx_per_tuple[word][w] = tuple_score
    #sort the dictionary
    ppx_per_tuple = {k: v for k, v in sorted(ppx_per_tuple.items(), key=lambda item: item[0])}

    #sort sub dictionary for values
    for key in ppx_per_tuple:
        ppx_per_tuple[key] = {k: v for k, v in sorted(ppx_per_tuple[key].items(), key=lambda item: item[1])}
    return ppx_per_tuple

In [10]:
def find_lowest_value_sentence(data):
    returns = []
    all_words = set(data.keys())
    
    for word in all_words:
        sentence = []
        current_word = word
        remaining_words = set([w for w in all_words if w != word])
        sentence.append(current_word)

        while len(remaining_words) > 0:
            for k, v in data[current_word].items():
                if k not in sentence:
                    sentence.append(k)
                    current_word = k
                    remaining_words.remove(k)
                    break
        returns.append(" ".join(sentence))
    return returns

In [13]:
output = []

for sen in Iterative_sentences:
    encoded = scorer.tokenizer.encode(sen, add_special_tokens=False)
    tokens = scorer.tokenizer.convert_ids_to_tokens(encoded)
    dictionary, init_tokens, words = get_dictionaries(tokens, encoded)
    ppx_per_tuple = get_tuples(dictionary)
    sentence = find_lowest_value_sentence(ppx_per_tuple)

    solution_final = ""
    score = float("inf")
    for sol in sentence:
        solutions = pd.DataFrame(
            {'id': [0],
            'text': sol})

        perplexities = scorer.get_perplexity(solutions["text"].tolist(), debug=False)
        if perplexities[0] < score:
            score = perplexities[0]
            solution_final = sol
    print(solution_final)
    output.append(solution_final)

ornament scrooge mistletoe reindeer elf family advent gingerbread chimney fireplace
reindeer scrooge mistletoe and the night elf family ornament gingerbread bake advent walk jump drive give laugh sleep chimney fireplace
nutcracker yuletide cheer grinch holiday decorations ornament gifts stocking carol holly jingle sleigh polar beard nice naughty chimney workshop magi
sleigh yuletide cheer unwrap the nutcracker grinch holiday decorations ornament gifts and of is nice carol sing jingle beard holly stocking chimney visit naughty polar magi workshop relax eat
puzzle game poinsettia eggnog fruitcake snowglobe hohoho merry and you have to it star night of joy peace hope that the season greeting card from kaggle in as we wish not with candy peppermint chocolate milk cookie candle wreath bow toy doll paper angel dream believe wonder workshop fireplace wrapping


KeyboardInterrupt: 

# Priority queue not optimal

In [1]:
sentences = ["ornament scrooge mistletoe reindeer elf family advent gingerbread chimney fireplace"
,"reindeer scrooge mistletoe and the night elf family ornament gingerbread bake advent walk jump drive give laugh sleep chimney fireplace"
,"nutcracker yuletide cheer grinch holiday decorations ornament gifts stocking carol holly jingle sleigh polar beard nice naughty chimney workshop magi"
,"sleigh yuletide cheer unwrap the nutcracker grinch holiday decorations ornament gifts and of is nice carol sing jingle beard holly stocking chimney visit naughty polar magi workshop relax eat"
,"puzzle game poinsettia eggnog fruitcake snowglobe hohoho merry and you have to it star night of joy peace hope that the season greeting card from kaggle in as we wish not with candy peppermint chocolate milk cookie candle wreath bow toy doll paper angel dream believe wonder workshop fireplace wrapping"
]

In [11]:
solutions = pd.DataFrame(
            {'id': [x for x in range(len(sentences))],
            'text': sentences})

perplexities = scorer.get_perplexity(solutions["text"].tolist(), debug=False)

In [12]:
perplexities

[1083.8569951461084,
 1815.118892385283,
 688.941423011762,
 867.5082310343188,
 479.0835461252901]

In [35]:
def get_prob(tokens: list, model, new_sentence):
    new_sent = new_sentence.copy()
    total_log_prob = 1.0
    for tok in tokens: # p(scroge) = p(sc) * p(roo) * p(ge)
        w = scorer.tokenizer.decode(tok)
        text_with_special = f"{scorer.tokenizer.bos_token}{' '.join(new_sentence)}{w}{scorer.tokenizer.eos_token}"

        minputs = scorer.tokenizer(text_with_special, return_tensors='pt', add_special_tokens=False,)
        
        outputs = scorer.model(**minputs)
        logits = outputs.logits
        shift_logits = logits[..., :-1, :].contiguous()  # Drop last prediction
        shift_labels = minputs['input_ids'][..., 1:].contiguous()  # Drop first input

        #applying softmax only to the 9+i token
        probs = torch.softmax(shift_logits[0][-1], dim=0)

        token_prob = probs[tok].item()

        total_log_prob *= token_prob

        # Append the chosen token to the current_ids for the next iteration
        new_sent.append(w)
    del minputs
    del outputs
    del logits
    del shift_logits
    del shift_labels
    return total_log_prob

In [44]:
from tqdm import tqdm
import heapq
import pandas as pd
import math
import logging

solutions = []
words = list(dictionary.keys())

priority_queue = []

new_sentence = sentence.split()[:8]
words_to_use = set(list(dictionary.keys())[8:])
desire_sentence = set(encoded[9:])

#save log to file
logging.basicConfig(filename='search.log', level=logging.INFO)

#start log
logging.info("Starting search")

# Initial state
heapq.heappush(priority_queue, (1.0, words_to_use.copy(), new_sentence.copy(), desire_sentence.copy()))

while priority_queue:
    # Pop the state with the smallest cost
    cost, words_left, current_sentence, desired_left = heapq.heappop(priority_queue)

    # If we've used all required words, we are done
    if len(desired_left) == 0:
        # Store the solution and break if desired
        solutions.append(" ".join(current_sentence))
        break

    # If no words left to use but still desired words remain, can't continue
    if len(words_left) == 0:
        continue

    # Compute probabilities for all candidate next words
    probs = {}
    for key in words_left:
        probs[key] = get_prob(dictionary[key], scorer.model, current_sentence)

    # For each candidate word, create a new state
    coef = -(len(encoded[9:]) - len(desired_left) + 3) 

    iter = True
    for w in probs:
        if (probs[w] * abs(cost)) > 10**coef:
            iter = False   
            new_sentence_state = current_sentence.copy()
            new_sentence_state.append(w)

            new_desired = desired_left.copy()
            new_desired -= set(dictionary[w])  # Remove words contributed by w

            new_words_left = words_left.copy()
            new_words_left.remove(w)

            # If we've achieved the desired sentence, we can record this solution
            if len(new_desired) == 0:

                solutions.append([" ".join(new_sentence_state[8:]), - (probs[w] * abs(cost))])
                logging.info(f"Found sentence: {' '.join(new_sentence_state)}")
                # Depending on your strategy, you might break here or continue searching
                break

            # Push the new state into the priority queue
            # Here cost is probs[w]. If you want to prioritize higher probabilities,
            # you could use -probs[w].
            logging.info(f"Current sentence: {' '.join(new_sentence_state)}, Desired: {desired_left}, Words left: {words_left}, Cost: {- (probs[w] * abs(cost))}")
            
            heapq.heappush(priority_queue, (- (probs[w] * abs(cost)), new_words_left, new_sentence_state, new_desired))
    if iter:
        #top 1 word
        w = max(probs, key=probs.get)
        new_sentence_state = current_sentence.copy()
        new_sentence_state.append(w)

        new_desired = desired_left.copy()
        new_desired -= set(dictionary[w])

        new_words_left = words_left.copy()
        new_words_left.remove(w)

        if len(new_desired) == 0:
            solutions.append(" ".join(new_sentence_state))
            logging.info(f"Found sentence: {' '.join(new_sentence_state)}")
            break

        logging.info(f"Current sentence: {' '.join(new_sentence_state)}, Desired: {desired_left}, Words left: {words_left}, Cost: {- (probs[w] * abs(cost))}")

        heapq.heappush(priority_queue, (- (probs[w] * abs(cost)), new_words_left, new_sentence_state, new_desired))
# Convert solutions to a dataframe if needed
#sort solutions by cost
solutions = sorted(solutions, key=lambda x: x[1])
solutions = [x[0] for x in solutions]
solutions_df = pd.DataFrame(
    {'id': range(len(solutions)),
     'text': solutions}
)


In [None]:
from tqdm import tqdm
solutions = []
words = list(dictionary.keys())

for j in range(len(words) - 8):
    new_sentence = sentence.split()[:8] + [words[8+j]]
    words_to_use = set(list(dictionary.keys())[8:]) - set([words[8+j]])
    init_tokens_ = set(list(init_tokens.keys())[8:]) - set(corr_enc[8+j])
    desire_sentence = set(encoded[9:]) - set(dictionary[words[8+j]])
    for i in tqdm(range(len(words) - 10)):
        
        probs = {}

        for key in words_to_use:
            probs[key] = get_prob(dictionary[key], scorer.model, new_sentence)

        #get index of highest probability
        best_word = max(probs, key=probs.get)

        new_sentence.append(best_word)

        desire_sentence -= set(dictionary[best_word])
        words_to_use -= set([best_word])
        if len(desire_sentence) == 0:
            break
    solutions.append(' '.join(new_sentence[8:]))

solutions = pd.DataFrame(
    {'id': range(len(solutions)),
    'text': solutions})
        

In [None]:
from tqdm import tqdm
solutions = []
words = list(dictionary.keys())
for j in range(len(words) - 8):
    new_sentence = sentence.split()[:8] + [words[8+j]]
    init_tokens_ = set(list(init_tokens.keys())[8:]) - set(corr_enc[8+j])
    desire_sentence = set(encoded[9:]) - set(dictionary[words[8+j]])
    for i in tqdm(range(len(words) - 9)):
        text_with_special = f"{scorer.tokenizer.bos_token}{' '.join(new_sentence)}{scorer.tokenizer.eos_token}"
        minputs = scorer.tokenizer(text_with_special, return_tensors='pt', add_special_tokens=False,)
        
        outputs = scorer.model(**minputs)
        logits = outputs.logits
        shift_logits = logits[..., :-1, :].contiguous()  # Drop last prediction
        shift_labels = minputs['input_ids'][..., 1:].contiguous()  # Drop first input

        #applying softmax only to the 9+i token
        #shift_logits[0][9+i] = F.softmax(shift_logits[0][9+i], dim=0)

        for key, val in dictionary.items():
            # val is array of tokens. do shift_logits[0][9+i][v] = prod (shift_logits[0][9+i][v] for v in val)
            product_value = torch.prod(shift_logits[0][9+i][val])
            for v in val:
                shift_logits[0][9+i][v] = product_value
                
        max_index = torch.topk(shift_logits[0][9+i], 256000).indices #beggining of the sentence


        

        # Find the lower index of all the words in desire_sentence
        index = []
        if len(desire_sentence) == 0:
            break
        
        tok = 0
        for enc in max_index:
            if enc.item() in init_tokens_:
                tok = enc.item()
                init_tokens_ = init_tokens_ - set([tok])
                break
        #find the complete word in dictionary given the token
        for word in dictionary:
            if tok in dictionary[word]:
                new_sentence = new_sentence + [word]
                desire_sentence = desire_sentence - set(dictionary[word])
                break

        del minputs
        del outputs
        del logits
        del shift_logits
        del shift_labels

    solutions.append(' '.join(new_sentence[8:]))

solutions = pd.DataFrame(
    {'id': range(len(solutions)),
    'text': solutions})
        

In [9]:
solutions = pd.DataFrame(
    {'id': [x for x in range(len(sentence))],
    'text': sentence})

# gingerbread ornament elf -> 105906.56265800883
# gingerbread elf ornament -> 38657.65136955225
# ornament gingerbread elf -> 182990.1307424248
# ornament elf gingerbread -> 143630.59930807285
# elf gingerbread ornament -> 9608176.378154187
# elf ornament gingerbread -> 27371147.346616127

perplexities = scorer.get_perplexity(solutions["text"].tolist(), debug=False)
perplexities

[2620.430116803095,
 2811.311136149555,
 1430.2792730137226,
 1947.3383101792642,
 2452.0691545936147,
 1577.0026226114692,
 3980.1011001356896,
 1558.6299855556538,
 4171.110140058703,
 2198.020584446467]

# First attempt using logist

In [None]:
import torch.nn.functional as F

out = outputs['logits'][0,1:-1].tolist()
out = torch.Tensor(out)
out = F.softmax(out, dim=1)
out = out.tolist()

In [None]:
import math
words = []
current_word = ""
idx = 0
logits = []
corr_enc = []
for i, token in enumerate(tokens):
    if token.startswith('▁') or idx==0:
        idx += 1

        if current_word:
            words.append(current_word)
            logits[len(words) - 1] = [logits[len(words) - 1][j] * x for j, x in enumerate(out[i])]
            corr_enc[len(words) - 1].append(encoded[i])
        # Remove the '▁' and start a new word
        if idx != 1:
            current_word = token[1:]
        else: 
            current_word = token[0:]
        logits.append(out[i])
        corr_enc.append([encoded[i]])
    else:
        # Continue the existing word
        current_word += token

# After the loop ends, add the last word if it exists
if current_word:
    words.append(current_word)

print(words)

In [None]:
probs = []
for x in logits:
    probs.append([x[tok] for tok in encoded])
probs

In [None]:
dictionary = {}
for row_idx, row in enumerate(probs):
    for col_idx, value in enumerate(row):
        token_str = int(encoded[col_idx])
        if token_str not in dictionary:
            dictionary[token_str] = [probs[row_idx][col_idx]]
        else:
            dictionary[token_str].append(probs[row_idx][col_idx])

In [None]:
dict_probs = {}
for j in corr_enc:
    key = str(j)
    arr = [1 for _ in range(len(words))]
    arr = [arr[i] * dictionary[x][i] for i in range(len(arr)) for x in j]
    dict_probs[key] = arr

In [None]:
translation = {}
for i, w in enumerate(words):
    translation[str(corr_enc[i])] = w

old_keys = list(dict_probs.keys())  
for old_key in old_keys:
    new_key = translation[old_key]
    dict_probs[new_key] = dict_probs.pop(old_key)

In [None]:
data = dict_probs  # Your dictionary

min_length = min(len(v) for v in data.values())

sorted_keys = []
remaining_keys = dict(data)  # make a copy we can modify

for i in range(min_length):
    candidates = []
    for k, arr in remaining_keys.items():
        # Determine which indices to average over
        start = max(0, i - 2)
        end = min(len(arr) - 2, i + 2)  # ensure we don't go out of range
        
        # If the array is too short to have these positions, skip
        # Actually, if 'end' < 'start', it means no suitable window
        if end < start:
            continue
        
        # Compute the average over the window [start:end]
        window_vals = arr[start:end+1]
        mse_val = sum([x**2 for x in window_vals]) / len(window_vals)
        avg_val = sum(window_vals) / len(window_vals)
        
        candidates.append((k, avg_val))
    
    if not candidates:
        # No candidates available, break early
        break
    
    # Find the key with the maximum average
    best_key = max(candidates, key=lambda x: x[1])[0]
    sorted_keys.append(best_key)
    
    # Remove the chosen key so it's not selected again
    del remaining_keys[best_key]

print("Sorted keys:", sorted_keys)



In [None]:
current_sequence = []
remaining_words = set(words)

while remaining_words:
    best_word = None
    best_ppl = float('inf')
    # Try adding each remaining word next and compute perplexity
    for w in remaining_words:
        candidate_sequence = current_sequence + [w]
        submission = pd.DataFrame({'id': [0], 'text': [" ".join(candidate_sequence)]})
        candidate_ppl = scorer.get_perplexity(submission["text"].tolist(), debug=False)[0]
        if candidate_ppl < best_ppl:
            best_ppl = candidate_ppl
            best_word = w
    
    current_sequence.append(best_word)
    remaining_words.remove(best_word)