## 데이터

In [1]:
import gc ## for memory
import os
import copy
from math import exp
from collections import Counter
from typing import List, Optional, Union

import numpy as np
import pandas as pd
import transformers
import torch

In [2]:
df_sample = pd.read_csv("sample_submission.csv")

In [3]:
[print(f"{i} : {df_sample.text[i]}\n") for i in range(6)]

0 : advent chimney elf family fireplace gingerbread mistletoe ornament reindeer scrooge

1 : advent chimney elf family fireplace gingerbread mistletoe ornament reindeer scrooge walk give jump drive bake the sleep night laugh and

2 : yuletide decorations gifts cheer holiday carol magi nutcracker polar grinch sleigh chimney workshop stocking ornament holly jingle beard naughty nice

3 : yuletide decorations gifts cheer holiday carol magi nutcracker polar grinch sleigh chimney workshop stocking ornament holly jingle beard naughty nice sing cheer and of the is eat visit relax unwrap

4 : hohoho candle poinsettia snowglobe peppermint eggnog fruitcake chocolate candy puzzle game doll toy workshop wonder believe dream hope peace joy merry season greeting card wrapping paper bow fireplace night cookie milk star wish wreath angel the to of and in that have it not with as you from we kaggle

5 : advent chimney elf family fireplace gingerbread mistletoe ornament reindeer scrooge walk give jump d

[None, None, None, None, None, None]

In [None]:
# os.system("kaggle competitions submit -c santa-2024 -f sample_submission.csv -m .")

100%|██████████| 1.50k/1.50k [00:00<00:00, 1.55kB/s]


Successfully submitted to Santa 2024 - The Perplexity Permutation Puzzle

0

## evaluation 구현

`-` 퍼플렉시티 산출을 위해서는 `Gemma 2 9B`모델이 필요하다.

In [4]:
## environment variable setting
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
# os.system("huggingface-cli login")
PAD_TOKEN_LABEL_ID = torch.nn.CrossEntropyLoss().ignore_index
DEVICE = torch.device('cuda')


class ParticipantVisibleError(Exception):
    pass


def score(
    solution: pd.DataFrame,
    submission: pd.DataFrame,
    row_id_column_name: str,
    model_path: str = 'google/gemma-2-9b',
    load_in_8bit: bool = False,
    clear_mem: bool = False,
) -> float:
    """
    Calculates the mean perplexity of submitted text permutations compared to an original text.

    Parameters
    ----------
    solution : DataFrame
        DataFrame containing the original text in a column named 'text'.
        Includes a row ID column specified by `row_id_column_name`.

    submission : DataFrame
        DataFrame containing the permuted text in a column named 'text'.
        Must have the same row IDs as the solution.
        Includes a row ID column specified by `row_id_column_name`.

    row_id_column_name : str
        Name of the column containing row IDs.
        Ensures aligned comparison between solution and submission.

    model_path : str, default='/kaggle/input/gemma-2/transformers/gemma-2-9b/2'
        Path to the serialized LLM.

    load_in_8bit : bool, default=False
        Use 8-bit quantization for the model. Requires CUDA.

    clear_mem : bool, default=False
        Clear GPU memory after scoring by clearing the CUDA cache.
        Useful for testing.

    Returns
    -------
    float
        The mean perplexity score. Lower is better.

    Raises
    ------
    ParticipantVisibleError
        If the submission format is invalid or submitted strings are not valid permutations.

    Examples
    --------
    >>> import pandas as pd
    >>> model_path = "/kaggle/input/gemma-2/transformers/gemma-2-9b/2"
    >>> solution = pd.DataFrame({
    ...     'id': [0, 1],
    ...     'text': ["this is a normal english sentence", "the quick brown fox jumps over the lazy dog"]
    ... })
    >>> submission = pd.DataFrame({
    ...     'id': [0, 1],
    ...     'text': ["sentence english normal a is this", "lazy the over jumps fox brown quick the dog"]
    ... })
    >>> score(solution, submission, 'id', model_path=model_path, clear_mem=True) > 0
    True
    """
    # Check that each submitted string is a permutation of the solution string
    sol_counts = solution.loc[:, 'text'].str.split().apply(Counter)
    sub_counts = submission.loc[:, 'text'].str.split().apply(Counter)
    invalid_mask = sol_counts != sub_counts
    if invalid_mask.any():
        raise ParticipantVisibleError(
            'At least one submitted string is not a valid permutation of the solution string.'
        )

    # Calculate perplexity for the submitted strings
    sub_strings = [
        ' '.join(s.split()) for s in submission['text'].tolist()
    ]  # Split and rejoin to normalize whitespace
    scorer = PerplexityCalculator(
        model_path=model_path,
        load_in_8bit=load_in_8bit,
    )  # Initialize the perplexity calculator with a pre-trained model
    perplexities = scorer.get_perplexity(
        sub_strings
    )  # Calculate perplexity for each submitted string

    if clear_mem:
        # Just move on if it fails. Not essential if we have the score.
        try:
            scorer.clear_gpu_memory()
        except:
            print('GPU memory clearing failed.')

    return float(np.mean(perplexities))


class PerplexityCalculator:
    """
    Calculates perplexity of text using a pre-trained language model.

    Adapted from https://github.com/asahi417/lmppl/blob/main/lmppl/ppl_recurrent_lm.py

    Parameters
    ----------
    model_path : str
        Path to the pre-trained language model

    load_in_8bit : bool, default=False
        Use 8-bit quantization for the model. Requires CUDA.

    device_map : str, default="auto"
        Device mapping for the model.
    """

    def __init__(
        self,
        model_path: str,
        load_in_8bit: bool = False,
        device_map: str = 'auto',
    ):
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
        # Configure model loading based on quantization setting and device availability
        if load_in_8bit:
            if DEVICE.type != 'cuda':
                raise ValueError('8-bit quantization requires CUDA device')
            quantization_config = transformers.BitsAndBytesConfig(load_in_8bit=True)
            self.model = transformers.AutoModelForCausalLM.from_pretrained(
                model_path,
                quantization_config=quantization_config,
                device_map=device_map,
            )
        else:
            self.model = transformers.AutoModelForCausalLM.from_pretrained(
                model_path,
                torch_dtype=torch.float16 if DEVICE.type == 'cuda' else torch.float32,
                device_map=device_map,
            )

        self.loss_fct = torch.nn.CrossEntropyLoss(reduction='none')

        self.model.eval()

    def get_perplexity(
        self, input_texts: Union[str, List[str]], batch_size: 32
    ) -> Union[float, List[float]]:
        """
        Calculates the perplexity of given texts.

        Parameters
        ----------
        input_texts : str or list of str
            A single string or a list of strings.

        batch_size : int, default=None
            Batch size for processing. Defaults to the number of input texts.

        debug : bool, default=False
            Print debugging information.

        Returns
        -------
        float or list of float
            A single perplexity value if input is a single string,
            or a list of perplexity values if input is a list of strings.

        Examples
        --------
        >>> import pandas as pd
        >>> model_path = "/kaggle/input/gemma-2/transformers/gemma-2-9b/2"
        >>> scorer = PerplexityCalculator(model_path=model_path)

        >>> submission = pd.DataFrame({
        ...     'id': [0, 1, 2],
        ...     'text': ["this is a normal english sentence", "thsi is a slihgtly misspelled zr4g sentense", "the quick brown fox jumps over the lazy dog"]
        ... })
        >>> perplexities = scorer.get_perplexity(submission["text"].tolist())
        >>> perplexities[0] < perplexities[1]
        True
        >>> perplexities[2] < perplexities[0]
        True

        >>> perplexities = scorer.get_perplexity(["this is a sentence", "another sentence"])
        >>> all(p > 0 for p in perplexities)
        True

        >>> scorer.clear_gpu_memory()
        """
        single_input = isinstance(input_texts, str)
        input_texts = [input_texts] if single_input else input_texts

        loss_list = []
        
        batches = len(input_texts)//batch_size + (len(input_texts)%batch_size != 0)
        
        for j in range(batches):
            a = j*batch_size
            b = (j+1)*batch_size
            input_batch = input_texts[a:b]
        
            with torch.no_grad():
                # Process each sequence independently
                for text in input_texts:
                    # Explicitly add sequence boundary tokens to the text
                    text_with_special = f"{self.tokenizer.bos_token}{text}{self.tokenizer.eos_token}"

                    # Tokenize
                    model_inputs = self.tokenizer(
                        text_with_special,
                        return_tensors='pt',
                        add_special_tokens=False,
                        padding=True
                    )

                    if 'token_type_ids' in model_inputs:
                        model_inputs.pop('token_type_ids')

                    model_inputs = {k: v.to(DEVICE) for k, v in model_inputs.items()}

                    # Get model output
                    output = self.model(**model_inputs, use_cache=False)
                    logits = output['logits']
                    
                    label = model_inputs["input_ids"]
                    label[label == self.tokenizer.pad_token_type_id] = PAD_TOKEN_LABEL_ID

                    # Shift logits and labels for calculating loss
                    shift_logits = logits[..., :-1, :].contiguous()  # Drop last prediction
                    shift_labels = label[..., 1:].contiguous()  # Drop first input

                    # Calculate token-wise loss
                    loss = self.loss_fct(
                        shift_logits.view(-1, shift_logits.size(-1)),
                        shift_labels.view(-1)
                    )

                    # Calculate average loss
                    loss = self.loss_fct(
                        shift_logits.view(-1, shift_logits.size(-1)),
                        shift_labels.view(-1)
                    )

                    loss = loss.view(len(logits), -1)
                    valid_length = (shift_labels != PAD_TOKEN_LABEL_ID).sum(dim=-1)
                    loss = torch.sum(loss, -1) / valid_length

                    loss_list += loss.cpu().tolist()


        ppl = [exp(i) for i in loss_list]

        return ppl[0] if single_input else ppl

    def clear_gpu_memory(self) -> None:
        """Clears GPU memory by deleting references and emptying caches."""
        if not torch.cuda.is_available():
            return

        # Delete model and tokenizer if they exist
        if hasattr(self, 'model'):
            del self.model
        if hasattr(self, 'tokenizer'):
            del self.tokenizer

        # Run garbage collection
        gc.collect()

        # Clear CUDA cache and reset memory stats
        with DEVICE:
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
            torch.cuda.reset_peak_memory_stats()

In [5]:
evaluatr = PerplexityCalculator("google/gemma-2-9b")

Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

In [19]:
print(f"score = {sum(evaluatr.get_perplexity(df_sample.text.to_list(), batch_size = 8))/6:.6f}")

score = 2173.306792


> 이렇게 계산하는 게 맞나...?(값이 조금 다름)

## 브루트 포스 알고리즘

In [60]:
for i in range(100) :
    evaluatr.get_perplexity(df_sample.text.to_list()[0], batch_size = 8)

In [6]:
"reindeer mistletoe elf gingerbread family advent"

'reindeer mistletoe elf gingerbread family advent'

In [6]:
from itertools import permutations

In [13]:
batch_size = 1024
model_path = "google/gemma-2-9b"
words = "reindeer mistletoe elf gingerbread family advent"

input_texts = [words+" "+" ".join(p) for p in permutations("chimney fireplace ornament scrooge".split())]

perplexities = evaluatr.get_perplexity(input_texts, batch_size)
best_index = np.argmin(perplexities)
print(input_texts[best_index])

reindeer mistletoe elf gingerbread family advent scrooge chimney fireplace ornament


In [17]:
print(f"{min(perplexities):.6f}")

467.985588
