In [1]:
"""Evaluation metric for Santa 2024."""

import gc
import os
from math import exp
from collections import Counter
from typing import List, Optional, Union

import numpy as np
import pandas as pd
import transformers
import torch

os.environ['OMP_NUM_THREADS'] = '1'
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
PAD_TOKEN_LABEL_ID = torch.nn.CrossEntropyLoss().ignore_index
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class ParticipantVisibleError(Exception):
    pass


def score(
    solution: pd.DataFrame,
    submission: pd.DataFrame,
    row_id_column_name: str,
    model_path: str = '/kaggle/input/gemma-2/transformers/gemma-2-9b/2',
    load_in_8bit: bool = False,
    clear_mem: bool = False,
) -> float:
    """
    Calculates the mean perplexity of submitted text permutations compared to an original text.

    Parameters
    ----------
    solution : DataFrame
        DataFrame containing the original text in a column named 'text'.
        Includes a row ID column specified by `row_id_column_name`.

    submission : DataFrame
        DataFrame containing the permuted text in a column named 'text'.
        Must have the same row IDs as the solution.
        Includes a row ID column specified by `row_id_column_name`.

    row_id_column_name : str
        Name of the column containing row IDs.
        Ensures aligned comparison between solution and submission.

    model_path : str, default='/kaggle/input/gemma-2/transformers/gemma-2-9b/2'
        Path to the serialized LLM.

    load_in_8bit : bool, default=False
        Use 8-bit quantization for the model. Requires CUDA.

    clear_mem : bool, default=False
        Clear GPU memory after scoring by clearing the CUDA cache.
        Useful for testing.

    Returns
    -------
    float
        The mean perplexity score. Lower is better.

    Raises
    ------
    ParticipantVisibleError
        If the submission format is invalid or submitted strings are not valid permutations.
    """
    # Check that each submitted string is a permutation of the solution string
    sol_counts = solution.loc[:, 'text'].str.split().apply(Counter)
    sub_counts = submission.loc[:, 'text'].str.split().apply(Counter)
    invalid_mask = sol_counts != sub_counts
    if invalid_mask.any():
        raise ParticipantVisibleError(
            'At least one submitted string is not a valid permutation of the solution string.'
        )

    # Calculate perplexity for the submitted strings
    sub_strings = [
        ' '.join(s.split()) for s in submission['text'].tolist()
    ]  # Split and rejoin to normalize whitespace
    scorer = PerplexityCalculator(
        model_path=model_path,
        load_in_8bit=load_in_8bit,
    )  # Initialize the perplexity calculator with a pre-trained model
    perplexities = scorer.get_perplexity(
        sub_strings
    )  # Calculate perplexity for each submitted string

    if clear_mem:
        # Just move on if it fails. Not essential if we have the score.
        try:
            scorer.clear_gpu_memory()
        except:
            print('GPU memory clearing failed.')

    return float(np.mean(perplexities))


class PerplexityCalculator:
    """
    Calculates perplexity of text using a pre-trained language model.

    Adapted from https://github.com/asahi417/lmppl/blob/main/lmppl/ppl_recurrent_lm.py

    Parameters
    ----------
    model_path : str
        Path to the pre-trained language model

    load_in_8bit : bool, default=False
        Use 8-bit quantization for the model. Requires CUDA.

    device_map : str, default="auto"
        Device mapping for the model.
    """

    def __init__(
        self,
        model_path: str,
        load_in_8bit: bool = False,
        device_map: str = 'auto',
    ):
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
        # Configure model loading based on quantization setting and device availability
        if load_in_8bit:
            if DEVICE.type != 'cuda':
                raise ValueError('8-bit quantization requires CUDA device')
            quantization_config = transformers.BitsAndBytesConfig(load_in_8bit=True)
            self.model = transformers.AutoModelForCausalLM.from_pretrained(
                model_path,
                quantization_config=quantization_config,
                device_map=device_map,
            )
        else:
            self.model = transformers.AutoModelForCausalLM.from_pretrained(
                model_path,
                torch_dtype=torch.float16 if DEVICE.type == 'cuda' else torch.float32,
                device_map=device_map,
            )

        self.loss_fct = torch.nn.CrossEntropyLoss(reduction='none')

        self.model.eval()

    def get_tokenizer(self):
        return self.tokenizer



    def get_perplexity(
        self, input_texts: Union[str, List[str]], debug=False
    ) -> Union[float, List[float]]:
        """
        Calculates the perplexity of given texts.

        Parameters
        ----------
        input_texts : str or list of str
            A single string or a list of strings.

        batch_size : int, default=None
            Batch size for processing. Defaults to the number of input texts.

        debug : bool, default=False
            Print debugging information.

        Returns
        -------
        float or list of float
            A single perplexity value if input is a single string,
            or a list of perplexity values if input is a list of strings.
        """
        single_input = isinstance(input_texts, str)
        input_texts = [input_texts] if single_input else input_texts

        loss_list = []
        with torch.no_grad():
            # Process each sequence independently
            for text in input_texts:
                # Explicitly add sequence boundary tokens to the text
                text_with_special = f"{self.tokenizer.bos_token}{text}{self.tokenizer.eos_token}"

                # Tokenize
                model_inputs = self.tokenizer(
                    text_with_special,
                    return_tensors='pt',
                    add_special_tokens=False,
                )

                if 'token_type_ids' in model_inputs:
                    model_inputs.pop('token_type_ids')

                model_inputs = {k: v.to(DEVICE) for k, v in model_inputs.items()}

                # Get model output
                output = self.model(**model_inputs, use_cache=False)
                logits = output['logits']

                # Shift logits and labels for calculating loss
                shift_logits = logits[..., :-1, :].contiguous()  # Drop last prediction
                shift_labels = model_inputs['input_ids'][..., 1:].contiguous()  # Drop first input

                # Calculate token-wise loss
                loss = self.loss_fct(
                    shift_logits.view(-1, shift_logits.size(-1)),
                    shift_labels.view(-1)
                )

                # Calculate average loss
                sequence_loss = loss.sum() / len(loss)
                loss_list.append(sequence_loss.cpu().item())

                # Debug output
                if debug:
                    print(f"\nProcessing: '{text}'")
                    print(f"With special tokens: '{text_with_special}'")
                    print(f"Input tokens: {model_inputs['input_ids'][0].tolist()}")
                    print(f"Target tokens: {shift_labels[0].tolist()}")
                    print(f"Input decoded: {self.tokenizer.decode(model_inputs['input_ids'][0])}")
                    print(f"Target decoded: {self.tokenizer.decode(shift_labels[0])}")
                    print(f"Individual losses: {loss.tolist()}")
                    print(f"Average loss: {sequence_loss.item():.4f}")

        ppl = [exp(i) for i in loss_list]

        if debug:
            print("\nFinal perplexities:")
            for text, perp in zip(input_texts, ppl):
                print(f"Text: '{text}'")
                print(f"Perplexity: {perp:.2f}")

        return ppl[0] if single_input else ppl

    def clear_gpu_memory(self) -> None:
        """Clears GPU memory by deleting references and emptying caches."""
        if not torch.cuda.is_available():
            return

        # Delete model and tokenizer if they exist
        if hasattr(self, 'model'):
            del self.model
        if hasattr(self, 'tokenizer'):
            del self.tokenizer

        # Run garbage collection
        gc.collect()

        # Clear CUDA cache and reset memory stats
        with DEVICE:
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
            torch.cuda.reset_peak_memory_stats()



In [2]:
import pandas as pd
model_path = "./model/gemma"
scorer = PerplexityCalculator(model_path=model_path, load_in_8bit=True)

Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

## Example 
``` py
submission = pd.DataFrame({
'id': [0, 1, 2],
'text': ["this is a normal english sentence", "thsi is a slihgtly misspelled zr4g sentense", "the quick brown fox jumps over the lazy dog"]
})
perplexities = scorer.get_perplexity(submission["text"].tolist(), debug=True)

sol_counts = submission.loc[:, 'text'].str.split().apply(Counter)
sol_counts


```

In [3]:
#sentence = "thsi is a slihgtly misspelled zr4g sentense"
sentence = "A christmas sentence will appear after the dot. advent chimney elf family fireplace gingerbread mistletoe ornament reindeer scrooge"
#sentence = "This is a sentence: I am from Peru"
#sentence = "advent chimney elf family fireplace gingerbread mistletoe ornament reindeer scrooge walk give jump drive bake the sleep night laugh and yuletide decorations gifts cheer holiday carol magi nutcracker polar grinch sleigh chimney workshop stocking ornament holly jingle beard naughty nice sing cheer and of the is eat visit relax unwrap hohoho candle poinsettia snowglobe peppermint eggnog fruitcake chocolate candy puzzle game doll toy workshop wonder believe dream hope peace joy merry season greeting card wrapping paper bow fireplace night cookie milk star wish wreath angel the to of and in that have it not with as you from we kaggle"
encoded = scorer.tokenizer.encode(sentence, add_special_tokens=False)
tokens = scorer.tokenizer.convert_ids_to_tokens(encoded)


"""text_with_special = f"{scorer.tokenizer.bos_token}{sentence}{scorer.tokenizer.eos_token}"
minputs = scorer.tokenizer(text_with_special, return_tensors='pt', add_special_tokens=False,)"""

'text_with_special = f"{scorer.tokenizer.bos_token}{sentence}{scorer.tokenizer.eos_token}"\nminputs = scorer.tokenizer(text_with_special, return_tensors=\'pt\', add_special_tokens=False,)'

In [4]:
scorer.tokenizer.decode(136507)

' gingerbread'

# Actual attempt

In [5]:
import torch.nn.functional as F
import math

dictionary = {}
init_tokens = {}
words = []
current_word = ""
idx = 0
corr_enc = []
for i, token in enumerate(tokens):
    if token.startswith('▁') or idx==0:
        idx += 1

        if current_word:
            words.append(current_word)
            dictionary[current_word] = corr_enc[len(words) - 1]
        # Remove the '▁' and start a new word
        if idx != 1:
            current_word = token[1:]
        else: 
            current_word = token[0:]
        corr_enc.append([encoded[i]])
        init_tokens[encoded[i]] = current_word
    else:
        # Continue the existing word
        corr_enc[len(words)].append(encoded[i])
        current_word += token

# After the loop ends, add the last word if it exists
if current_word:
    words.append(current_word)
    dictionary[current_word] = corr_enc[len(words) - 1]

print(dictionary)

{'A': [235280], 'christmas': [14496], 'sentence': [13060], 'will': [877], 'appear': [4824], 'after': [1452], 'the': [573], 'dot.': [12846, 235265], 'advent': [12002], 'chimney': [67905], 'elf': [52931], 'family': [2730], 'fireplace': [43485], 'gingerbread': [136507], 'mistletoe': [7727, 165493], 'ornament': [29138], 'reindeer': [103360], 'scrooge': [1513, 80108, 541]}


In [6]:
corr_enc[8]

[12002]

In [33]:
def get_prob(tokens: list, model, new_sentence):
    new_sent = new_sentence.copy()
    total_log_prob = 1.0
    for tok in tokens:
        w = scorer.tokenizer.decode(tok)
        text_with_special = f"{scorer.tokenizer.bos_token}{' '.join(new_sentence)}{w}{scorer.tokenizer.eos_token}"

        minputs = scorer.tokenizer(text_with_special, return_tensors='pt', add_special_tokens=False,)
        
        outputs = scorer.model(**minputs)
        logits = outputs.logits
        shift_logits = logits[..., :-1, :].contiguous()  # Drop last prediction
        shift_labels = minputs['input_ids'][..., 1:].contiguous()  # Drop first input

        #applying softmax only to the 9+i token
        #probs = torch.softmax(shift_logits[0][-1], dim=0)

        token_prob = shift_logits[0][-1][tok].item()

        total_log_prob *= token_prob

        # Append the chosen token to the current_ids for the next iteration
        new_sent.append(w)
    del minputs
    del outputs
    del logits
    del shift_logits
    del shift_labels
    return total_log_prob

In [34]:
from tqdm import tqdm
solutions = []
words = list(dictionary.keys())

for j in range(len(words) - 8):
    new_sentence = sentence.split()[:8] + [words[8+j]]
    words_to_use = set(list(dictionary.keys())[8:]) - set([words[8+j]])
    init_tokens_ = set(list(init_tokens.keys())[8:]) - set(corr_enc[8+j])
    desire_sentence = set(encoded[9:]) - set(dictionary[words[8+j]])
    for i in tqdm(range(len(words) - 10)):
        
        probs = {}

        for key in words_to_use:
            probs[key] = get_prob(dictionary[key], scorer.model, new_sentence)

        #get index of highest probability
        best_word = max(probs, key=probs.get)

        new_sentence.append(best_word)

        desire_sentence -= set(dictionary[best_word])
        words_to_use -= set([best_word])
        if len(desire_sentence) == 0:
            break
    solutions.append(' '.join(new_sentence[8:]))

solutions = pd.DataFrame(
    {'id': range(len(solutions)),
    'text': solutions})
        

100%|██████████| 8/8 [00:07<00:00,  1.08it/s]
100%|██████████| 8/8 [00:07<00:00,  1.10it/s]
100%|██████████| 8/8 [00:07<00:00,  1.11it/s]
100%|██████████| 8/8 [00:07<00:00,  1.13it/s]
100%|██████████| 8/8 [00:07<00:00,  1.10it/s]
100%|██████████| 8/8 [00:07<00:00,  1.07it/s]
100%|██████████| 8/8 [00:07<00:00,  1.12it/s]
100%|██████████| 8/8 [00:07<00:00,  1.05it/s]
100%|██████████| 8/8 [00:07<00:00,  1.07it/s]
100%|██████████| 8/8 [00:06<00:00,  1.15it/s]


In [None]:
from tqdm import tqdm
solutions = []
words = list(dictionary.keys())
for j in range(len(words) - 8):
    new_sentence = sentence.split()[:8] + [words[8+j]]
    init_tokens_ = set(list(init_tokens.keys())[8:]) - set(corr_enc[8+j])
    desire_sentence = set(encoded[9:]) - set(dictionary[words[8+j]])
    for i in tqdm(range(len(words) - 9)):
        text_with_special = f"{scorer.tokenizer.bos_token}{' '.join(new_sentence)}{scorer.tokenizer.eos_token}"
        minputs = scorer.tokenizer(text_with_special, return_tensors='pt', add_special_tokens=False,)
        
        outputs = scorer.model(**minputs)
        logits = outputs.logits
        shift_logits = logits[..., :-1, :].contiguous()  # Drop last prediction
        shift_labels = minputs['input_ids'][..., 1:].contiguous()  # Drop first input

        #applying softmax only to the 9+i token
        #shift_logits[0][9+i] = F.softmax(shift_logits[0][9+i], dim=0)

        for key, val in dictionary.items():
            # val is array of tokens. do shift_logits[0][9+i][v] = prod (shift_logits[0][9+i][v] for v in val)
            product_value = torch.prod(shift_logits[0][9+i][val])
            for v in val:
                shift_logits[0][9+i][v] = product_value
                
        max_index = torch.topk(shift_logits[0][9+i], 256000).indices #beggining of the sentence


        

        # Find the lower index of all the words in desire_sentence
        index = []
        if len(desire_sentence) == 0:
            break
        
        tok = 0
        for enc in max_index:
            if enc.item() in init_tokens_:
                tok = enc.item()
                init_tokens_ = init_tokens_ - set([tok])
                break
        #find the complete word in dictionary given the token
        for word in dictionary:
            if tok in dictionary[word]:
                new_sentence = new_sentence + [word]
                desire_sentence = desire_sentence - set(dictionary[word])
                break

        del minputs
        del outputs
        del logits
        del shift_logits
        del shift_labels

    solutions.append(' '.join(new_sentence[8:]))

solutions = pd.DataFrame(
    {'id': range(len(solutions)),
    'text': solutions})
        

100%|██████████| 9/9 [00:03<00:00,  2.36it/s]
100%|██████████| 9/9 [00:04<00:00,  2.24it/s]
100%|██████████| 9/9 [00:03<00:00,  2.31it/s]
100%|██████████| 9/9 [00:03<00:00,  2.31it/s]
100%|██████████| 9/9 [00:03<00:00,  2.29it/s]
100%|██████████| 9/9 [00:03<00:00,  2.36it/s]
100%|██████████| 9/9 [00:03<00:00,  2.39it/s]
100%|██████████| 9/9 [00:03<00:00,  2.25it/s]
100%|██████████| 9/9 [00:03<00:00,  2.28it/s]
100%|██████████| 9/9 [00:03<00:00,  2.26it/s]


In [36]:
perplexities = scorer.get_perplexity(solutions["text"].tolist(), debug=False)
perplexities

[2285.5797704346282,
 1917.1476280156185,
 5272.7772756597515,
 3949.1277076439605,
 3051.6493685247633,
 2064.849249361922,
 3075.5837511620985,
 3173.213259472856,
 1516.5886008725572,
 2241.3726559431707]

In [35]:
pd.set_option('display.max_colwidth', None)
print(solutions)

   id  \
0   0   
1   1   
2   2   
3   3   
4   4   
5   5   
6   6   
7   7   
8   8   
9   9   

                                                                           text  
0  advent scrooge mistletoe chimney fireplace elf gingerbread ornament reindeer  
1  chimney scrooge mistletoe elf gingerbread reindeer ornament fireplace advent  
2  elf scrooge mistletoe chimney gingerbread reindeer ornament fireplace advent  
3  family scrooge mistletoe chimney fireplace elf gingerbread ornament reindeer  
4  fireplace scrooge chimney mistletoe gingerbread elf ornament reindeer advent  
5  gingerbread scrooge chimney mistletoe elf reindeer ornament fireplace advent  
6  mistletoe scrooge chimney fireplace elf gingerbread ornament reindeer advent  
7  ornament scrooge fireplace mistletoe chimney gingerbread elf reindeer advent  
8  reindeer scrooge chimney mistletoe gingerbread elf ornament fireplace advent  
9  scrooge mistletoe chimney fireplace elf gingerbread ornament reindeer advent 

# First attempt using logist

In [30]:
import torch.nn.functional as F

out = outputs['logits'][0,1:-1].tolist()
out = torch.Tensor(out)
out = F.softmax(out, dim=1)
out = out.tolist()

In [40]:
import math
words = []
current_word = ""
idx = 0
logits = []
corr_enc = []
for i, token in enumerate(tokens):
    if token.startswith('▁') or idx==0:
        idx += 1

        if current_word:
            words.append(current_word)
            logits[len(words) - 1] = [logits[len(words) - 1][j] * x for j, x in enumerate(out[i])]
            corr_enc[len(words) - 1].append(encoded[i])
        # Remove the '▁' and start a new word
        if idx != 1:
            current_word = token[1:]
        else: 
            current_word = token[0:]
        logits.append(out[i])
        corr_enc.append([encoded[i]])
    else:
        # Continue the existing word
        current_word += token

# After the loop ends, add the last word if it exists
if current_word:
    words.append(current_word)

print(words)

['I', 'am', 'from', 'Peru']


In [48]:
probs = []
for x in logits:
    probs.append([x[tok] for tok in encoded])
probs

[[5.720867918121862e-17,
  2.914890042105453e-07,
  6.084304055318327e-10,
  3.438256276667309e-26],
 [3.575450173774627e-20,
  1.667731239999664e-13,
  2.75705662393314e-08,
  8.030782328181487e-16],
 [3.4125374978022107e-18,
  1.3480570017581046e-12,
  2.362354899526456e-09,
  2.0935555257887902e-07],
 [5.848318096468574e-07,
  2.372890776314307e-05,
  0.0004339563602115959,
  0.00017261592438444495]]

In [49]:
dictionary = {}
for row_idx, row in enumerate(probs):
    for col_idx, value in enumerate(row):
        token_str = int(encoded[col_idx])
        if token_str not in dictionary:
            dictionary[token_str] = [probs[row_idx][col_idx]]
        else:
            dictionary[token_str].append(probs[row_idx][col_idx])

In [50]:
dict_probs = {}
for j in corr_enc:
    key = str(j)
    arr = [1 for _ in range(len(words))]
    arr = [arr[i] * dictionary[x][i] for i in range(len(arr)) for x in j]
    dict_probs[key] = arr

In [51]:
translation = {}
for i, w in enumerate(words):
    translation[str(corr_enc[i])] = w

old_keys = list(dict_probs.keys())  
for old_key in old_keys:
    new_key = translation[old_key]
    dict_probs[new_key] = dict_probs.pop(old_key)

In [52]:
data = dict_probs  # Your dictionary

min_length = min(len(v) for v in data.values())

sorted_keys = []
remaining_keys = dict(data)  # make a copy we can modify

for i in range(min_length):
    candidates = []
    for k, arr in remaining_keys.items():
        # Determine which indices to average over
        start = max(0, i - 2)
        end = min(len(arr) - 2, i + 2)  # ensure we don't go out of range
        
        # If the array is too short to have these positions, skip
        # Actually, if 'end' < 'start', it means no suitable window
        if end < start:
            continue
        
        # Compute the average over the window [start:end]
        window_vals = arr[start:end+1]
        mse_val = sum([x**2 for x in window_vals]) / len(window_vals)
        avg_val = sum(window_vals) / len(window_vals)
        
        candidates.append((k, avg_val))
    
    if not candidates:
        # No candidates available, break early
        break
    
    # Find the key with the maximum average
    best_key = max(candidates, key=lambda x: x[1])[0]
    sorted_keys.append(best_key)
    
    # Remove the chosen key so it's not selected again
    del remaining_keys[best_key]

print("Sorted keys:", sorted_keys)



Sorted keys: ['am', 'I', 'Peru', 'from']


In [None]:
current_sequence = []
remaining_words = set(words)

while remaining_words:
    best_word = None
    best_ppl = float('inf')
    # Try adding each remaining word next and compute perplexity
    for w in remaining_words:
        candidate_sequence = current_sequence + [w]
        submission = pd.DataFrame({'id': [0], 'text': [" ".join(candidate_sequence)]})
        candidate_ppl = scorer.get_perplexity(submission["text"].tolist(), debug=False)[0]
        if candidate_ppl < best_ppl:
            best_ppl = candidate_ppl
            best_word = w
    
    current_sequence.append(best_word)
    remaining_words.remove(best_word)

# Totally brute force

In [27]:
from itertools import combinations, permutations
from nltk import ngrams
from tqdm import tqdm

def get_permutations(words):
    perms = []
    for perm in permutations(words, len(words)):
        perms.append(perm)
    return perms

def get_permutations2(words):
    perms = []
    for perm in ngrams(words, len(words)):
        perms.append(perm)
    return perms

comb = get_permutations(words)
print(len(comb))
print(len(words))

best_ppl = float('inf')
best_perm = None
for c in tqdm(comb):
    submission = pd.DataFrame({
    'id': [0],
    'text': [" ".join(c)]
    })
    perplexities = scorer.get_perplexity(submission["text"].tolist(), debug=False)
    if perplexities[0] < best_ppl:
        best_ppl = perplexities[0]
        best_perm = c

3628800
10


  0%|          | 4286/3628800 [09:10<129:16:30,  7.79it/s]


KeyboardInterrupt: 