# Preparing Data for Distillation

Charles Ciampa

In [None]:
import numpy as np
import pandas as pd
from typing import Dict, Callable
import warnings

from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from sklearn.model_selection import train_test_split
import torch
import torch.nn.functional as F
import os
from tqdm.notebook import tqdm
from torch.utils.data import DataLoader

# Added for notifier import.
# Prepend the parent directory to sys.path so the local `notifier` module (presumably one level up) is importable.
# NOTE(review): the relative entry '..' is resolved against the current working directory, not this notebook's folder.
import sys
sys.path.insert(0, '..')
from notifier import Notifier

# Notification sender; later cells call NOTIFIER.send_notification(...) when long-running steps finish.
NOTIFIER = Notifier(enabled=True)

In [2]:
from huggingface_hub import notebook_login

# Log in to the Hugging Face Hub. `new_session` is passed by keyword: recent huggingface_hub
# versions make it keyword-only and deprecate the positional form. False avoids asking for a
# new token when one is already saved on this machine.
notebook_login(new_session=False)



In [3]:
from huggingface_hub import scan_cache_dir

# Print a report of the local Hugging Face cache.
cache_report = scan_cache_dir()
print(cache_report)

# Recipe for previewing how much space deleting a cached revision would free:
# delete_strategy = scan_cache_dir().delete_revisions(
#     "8d8ffc158a3bee9fbb03afacdfc347c823c5ec8b"
# )
# print("Will free " + delete_strategy.expected_freed_size_str)



In [None]:
class DistilModelData:
    """ Class will load data from a tokenizer, model, and a dataset. Also a prompt and labels will be provided.
    """
    def __init__(self):
        # Initialize the variables
        self._train_df = None
        self._test_df = None
        self._labels = None
        self._reversed_labels = None
        self._prompt: Callable | None = None
        self._num_examples: int = 0
        self._model: AutoModelForCausalLM = None
        self._tokenizer: AutoTokenizer = None
        self._sample = None
    
    def set_labels(self, labels: Dict[int, str]):
        """Provided a dictionary of labels it will se the labels. The keys are the integer labels in the dataset and the values of the dictionary are the labels for the prompt into the models.

        Args:
            labels (Dict[int, str]): The labels to be saved

        Raises:
            ValueError: A dictionary must be provided as input otherwise an error will be risen.
            ValueError: If not all the keys are integers it will cause issues.
            ValueError: If not all the values are strings it will raise an error.
        """
        if self._train_df is None or self._test_df is None:
            raise ValueError("The train and test dataframes have not be set yet. You must set to ensure that each of the labels in the dataframe have been set.")
        if not isinstance(labels, dict):
            raise ValueError("Labels must be a dictionary")
        if not all(isinstance(k, int) for k in labels.keys()):
            raise ValueError("Label keys must be integers")
        if not all(isinstance(v, str) for v in labels.values()):
            raise ValueError("Label values must be strings")
        label_keys = set(labels.keys())
        train_df_labels = set(self._train_df['label'].unique())
        test_df_labels = set(self._test_df["label"].unique())
        if not train_df_labels.issubset(label_keys) or not test_df_labels.issubset(label_keys):
            raise ValueError(f"The provided labels are missing assigned string values for the following values: {', '.join(train_df_labels.difference(label_keys).union(test_df_labels.difference(label_keys)))}.")
        self._labels = labels
        self._reversed_labels = {v: k for k, v in self._labels.items()}
    
    def set_num_examples_in_prompt(self, num: int = 0):
        """Provided an integer it will set the number of examples in the prompt.

        Args:
            num (int): The number of examples to be saved.

        Raises:
            ValueError: An integer must be provided.
        """
        if not isinstance(num, int):
            raise ValueError("An integer must be provided")
        self._num_examples = num
    
    def set_prompt(self, prompt_func: Callable[[str, dict, pd.DataFrame], str]):
        # Prompt function takes in as such f(string to label, label options, example dataframe) -> prompt string
        self._prompt = prompt_func

    def set_model(self, model_name: str, bnb_config: None | BitsAndBytesConfig = None):
        if not isinstance(model_name, str):
            raise ValueError("A model name must be provided as a string")
        
        self._tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)

        self._model = AutoModelForCausalLM.from_pretrained(
            model_name,
            quantization_config=bnb_config,
            dtype=torch.float16,
            device_map="auto",
            low_cpu_mem_usage=True,
            trust_remote_code=True,
        )

        print(self._model.device)
    
    def get_inference_prompt(self, text = "[FILL IN]") -> str:
        if self._prompt is None:
            raise ValueError("Prompt has not been set yet.")
        if self._train_df is None:
            raise ValueError("Train dataset has not been set yet.")
        if self._test_df is None:
            raise ValueError("Test dataset has not been set yet.")
        if self._labels is None:
            raise ValueError("Labels have not been set yet.")
        sample = self._sample if self._sample is not None else self._train_df.sample(self._num_examples)
        inference_prompt = self._prompt(text, self._labels, sample)
        return inference_prompt

    def reset_datasets_and_labels(self):
        self._labels = None
        self._train_df = None
        self._test_df = None
    
    def set_datasets_from_path(
        self,
        train_path: str,
        test_path: str,
        rename_columns: Dict[str, str] = {},
        create_columns: None | Callable[[pd.DataFrame], pd.DataFrame] = None,
        ignore_common_text_thresh: float = 0,
    ):
        # Loads the data
        try:
            train_temp = pd.read_parquet(train_path)
            test_temp = pd.read_parquet(test_path)
            # Renames the columns if provided any renames. This is there to help you make sure there is a text and label column as these will be used in this code
            train_temp.rename(columns=rename_columns, inplace=True)
            test_temp.rename(columns=rename_columns, inplace=True)
            # Runs a provided function which modifies the data to ensure that there are columns text and label, and their values are appropriet.
            if create_columns is not None:
                train_temp = create_columns(train_temp)
                test_temp = create_columns(test_temp)
        except Exception as e:
            raise e
        # This is where it actually sets the data. At this point no errors should have occured so its safe to finally set the values. The last checks will be here.
        self.set_datasets(
            train_temp.copy(),
            test_temp.copy(),
            ignore_common_text_thresh=ignore_common_text_thresh,
        )
    

    def set_datasets(self, train_df: pd.DataFrame, test_df: pd.DataFrame, ignore_common_text_thresh: float = 0):
        """Sets the train and test datasets.

        Args:
            train_df (pd.DataFrame): The training dataframe.
            test_df (pd.DataFrame): The testing dataframe.

        Raises:
            ValueError: Both inputs must be pandas DataFrames.
            ValueError: Train DataFrame must have 'text' and 'label' columns.
            ValueError: Test DataFrame must have 'text' and 'label' columns.
            ValueError: Train DataFrame 'label' column must be of integer type.
            ValueError: Test DataFrame 'label' column must be of integer type.
            ValueError: Train DataFrame 'text' column must be of string type.
            ValueError: Test DataFrame 'text' column must be of string type.
            ValueError: Train and Test DataFrames share common text entries. Data leakage detected.
        """
        # Ensures that both of the inputs are DataFrames
        if not isinstance(train_df, pd.DataFrame) or not isinstance(test_df, pd.DataFrame):
            raise ValueError("Both inputs must be pandas DataFrames.")
        
        # Checks that there is a labels and text column
        if "text" not in train_df.columns or "label" not in train_df.columns:
            raise ValueError("Train DataFrame must have 'text' and 'label' columns.")
        if "text" not in test_df.columns or "label" not in test_df.columns:
            raise ValueError("Test DataFrame must have 'text' and 'label' columns.")
        
        # Ensure that the labels are of the integer type
        if not pd.api.types.is_integer_dtype(train_df["label"]):
            raise ValueError("Train DataFrame 'label' column must be of integer type.")
        if not pd.api.types.is_integer_dtype(test_df["label"]):
            raise ValueError("Test DataFrame 'label' column must be of integer type.")
        
        # Ensure that the text columns are a string value
        if not pd.api.types.is_string_dtype(train_df["text"]):
            raise ValueError("Train DataFrame 'text' column must be of string type")
        if not pd.api.types.is_string_dtype(test_df["text"]):
            raise ValueError("Test DataFrame 'text' column must be of string type")
        
        # Check for overlapping data between train and test sets based on the 'text' column
        common_texts = set(train_df["text"]).intersection(set(test_df["text"]))
        if common_texts:
            perc = len(common_texts) / len(test_df) 
            err = f"Data leakage detected! Train and Test DataFrames share {len(common_texts)} ({perc:.2%} of testing dataset) common text entries."
            if perc > ignore_common_text_thresh:
                raise ValueError(err)
            else:
                warnings.warn(err)
        self._train_df = train_df
        self._test_df = test_df

    def distil_labels(self, batch_size: int = 8, label_prob_prefex: str = 'label_'):
        if self._labels is None:
            raise ValueError("Labels must be set.")
        if self._train_df is None or self._test_df is None:
            raise ValueError("Datasets must be set.")
        if self._model is None or self._tokenizer is None:
            raise ValueError("Model and Tokenizer must be set")
        if self._prompt is None:
            raise ValueError("Prompt must be set.")
        if self._model is None or self._tokenizer is None:
            raise ValueError("Model and Tokenizer have not been set yet.")
        
        # 1. Pre-calculate label token IDs (do this once, not in the loop)
        label_token_map = {}
        with torch.inference_mode():
            for label_str in self._labels.values():
                # Add space because many tokenizers are space-sensitive
                tokens = self._tokenizer.encode(label_str, add_special_tokens=False)
                label_token_map[label_str] = tokens[0]
        
        self._tokenizer.padding_side = "left"
        if self._tokenizer.pad_token is None:
            self._tokenizer.pad_token = self._tokenizer.eos_token

        ordered_labels = list(self._labels.values())
        ordered_ids = [label_token_map[l] for l in ordered_labels]
        target_token_ids = torch.tensor(ordered_ids, device=self._model.device)
        def process_batches(prompts, desc):
            """Created by Claude Sonnet 4.5"""
            # all_probs = []
            
            with torch.inference_mode():
                prompt_lengths = [(i, len(self._tokenizer.encode(p, add_special_tokens=False))) for i, p in enumerate(prompts)]
                prompt_lengths.sort(key=lambda x: x[1])

                sorted_indices = [x[0] for x in prompt_lengths]
                sorted_prompts = [prompts[i] for i in sorted_indices]
                batch_results = [None] * len(prompts)
                for i in tqdm(range(0, len(sorted_prompts), batch_size), desc=desc):
                    batch_prompts = sorted_prompts[i : i + batch_size]
                    batch_indices = sorted_indices[i : i + batch_size]

                    # Tokenize batch
                    inputs = self._tokenizer(
                        batch_prompts,
                        return_tensors="pt",
                        padding=True,
                        truncation=True
                    ).to(self._model.device)
                    
                    # Forward pass
                    outputs = self._model(**inputs)
                    
                    # Get logits of last token for each sequence
                    # Shape: [batch_size, vocab_size]
                    next_token_logits = outputs.logits[:, -1, :]
                    
                    # Extract logits only for our label tokens
                    # Shape: [batch_size, num_labels]
                    selected_logits = next_token_logits[:, target_token_ids]
                    
                    # Softmax over only the selected labels
                    probs = F.softmax(selected_logits, dim=-1)
                    
                    # Convert to list of dicts
                    probs_cpu = probs.detach().cpu().numpy()
                    for j, row_probs in enumerate(probs_cpu):
                        prob_dict = {
                            self._reversed_labels[label_str]: p
                            for label_str, p in zip(ordered_labels, row_probs)
                        }
                        batch_results[batch_indices[j]] = prob_dict
                    # CRITICAL: Delete GPU tensors explicitly
                    del inputs, outputs, next_token_logits, selected_logits, probs, probs_cpu
                    
                    # Clear CUDA cache periodically (every 10 batches)
                    # if i % (batch_size * 10) == 0:
                    torch.cuda.empty_cache()
                torch.cuda.empty_cache()
                return batch_results
        
        # Create the prompt for the training
        train_prompts = [self._prompt(row["text"], self._labels, self._train_df.drop(i).sample(self._num_examples)) for i, row in self._train_df.iterrows()]
        
        train_probs = process_batches(train_prompts, "Getting Probability of Labels Training Dataset")

        train_probs = pd.DataFrame(train_probs).add_prefix(label_prob_prefex)

        self._train_df = pd.concat([self._train_df.reset_index(drop=True), train_probs], axis=1)

        # Get text exampls for the testing prompts
        self._sample = self._train_df.sample(self._num_examples)

        # Create test prompts
        test_prompts = [self._prompt(row["text"], self._labels, self._sample) for i, row in self._test_df.iterrows()]

        # Get the probabilities
        test_probs = process_batches(test_prompts, "Getting Probability of Labels Testing Dataset")

        test_probs = pd.DataFrame(test_probs).add_prefix(label_prob_prefex)

        self._test_df = pd.concat([self._test_df.reset_index(drop=True), test_probs], axis=1)

    def folder_export(self, path: str):
        if self._test_df  is None or self._train_df is None:
            raise ValueError("The datasets have not been set.")
        self._train_df.to_csv(f"{path}train.csv", index=False)
        self._test_df.to_csv(f"{path}test.csv", index=False)
    
    def export_files(self, train_path: str, test_path: str):
        if self._test_df is None or self._train_df is None:
            raise ValueError("The datasets have not been set.")
        self._train_df.to_csv(train_path, index=False)
        self._test_df.to_csv(test_path, index=False)


In [5]:
# Pool the official IMDB train and test splits, then drop duplicate review texts
# so the re-split below cannot place the same text on both sides.
df_total = (
    pd.concat(
        [
            pd.read_parquet("hf://datasets/stanfordnlp/imdb/plain_text/train-00000-of-00001.parquet"),
            pd.read_parquet("hf://datasets/stanfordnlp/imdb/plain_text/test-00000-of-00001.parquet"),
        ],
        ignore_index=True,
    )
    .drop_duplicates("text", ignore_index=True)
)
df_total.head(5)

Unnamed: 0,text,label
0,I rented I AM CURIOUS-YELLOW from my video sto...,0
1,"""I Am Curious: Yellow"" is a risible and preten...",0
2,If only to avoid making this type of film in t...,0
3,This film was probably inspired by Godard's Ma...,0
4,"Oh, brother...after hearing about this ridicul...",0


In [6]:
# Even 50/50 re-split of the pooled reviews; the fixed seed keeps it reproducible.
df_train, df_test = train_test_split(
    df_total,
    random_state=6120,
    test_size=0.5,
)

In [7]:
# Label counts per split, to confirm the split stays roughly balanced.
for split_frame in (df_train, df_test):
    print(split_frame["label"].value_counts())

label
1    12433
0    12358
Name: count, dtype: int64
label
1    12451
0    12340
Name: count, dtype: int64


In [8]:
model_distallation = DistilModelData()

# Datasets must be set before labels, because set_labels checks every label value in both splits.
model_distallation.set_datasets(df_train, df_test)
model_distallation.set_labels(labels={0: "Negative", 1: "Positive"})

In [None]:
# Teacher model used for scoring. Alternative: "meta-llama/Meta-Llama-3.1-8B-Instruct".
model_distallation.set_model(model_name="dphn/Dolphin3.0-Llama3.1-8B")

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

cuda:0


In [10]:
def model_prompt(ex: str, labels: dict, examples: pd.DataFrame) -> str:
    """Build a few-shot sentiment-classification prompt.

    Args:
        ex: Text to classify. It is placed last, followed by an open "Sentiment: "
            slot, so the model's next token is the predicted label.
        labels: Mapping of integer label -> label string; the strings are listed as the options.
        examples: Few-shot rows with "text" and integer "label" columns (may be empty).

    Returns:
        The prompt string.

    Raises:
        ValueError: If ``labels`` is empty.
    """
    label_names = list(labels.values())
    if not label_names:
        raise ValueError("At least one label is required to build the prompt.")
    if len(label_names) == 1:
        options = label_names[0]
    else:
        # Keeps the existing wording (e.g. "Negative, or Positive") so prompts match earlier runs.
        options = f"{', '.join(label_names[:-1])}, or {label_names[-1]}"
    parts = [f"Classify the sentiment of the following texts as either {options}.\n\n"]
    for _, row in examples.iterrows():
        parts.append(f'Text: {row["text"]}\nSentiment: {labels[row["label"]]}\n\n')
    parts.append(f'Text: {ex}\nSentiment: ')
    return "".join(parts)

model_distallation.set_prompt(prompt_func=model_prompt)

In [11]:
# Three few-shot examples per prompt.
model_distallation.set_num_examples_in_prompt(num=3)

In [None]:
# Score every train/test text (three prompts per forward pass), then send a notification.
model_distallation.distil_labels(batch_size=3, label_prob_prefex="label_")
NOTIFIER.send_notification("The distillation of the labels has been completed successfully.")
# Preview the prompt template, with the placeholder in place of the text to classify.
print(model_distallation.get_inference_prompt(text="[FILL IN]"))

Getting Probability of Labels Training Dataset:   0%|          | 0/8264 [00:00<?, ?it/s]

Getting Probability of Labels Testing Dataset:   0%|          | 0/8264 [00:00<?, ?it/s]

In [None]:
# Write train.csv and test.csv with the appended probability columns.
model_distallation.folder_export(path="../data/dolphin/")
NOTIFIER.send_notification("The export of the distilled data has been completed successfully.")

In [16]:
# Display the training set with the appended per-label probability columns (reads a private attribute).
model_distallation._train_df

Unnamed: 0,text,label,label_0,label_1
0,The volleyball genre is strangely overlooked b...,0,0.046021,0.954102
1,I got interested in this movie because somebod...,0,0.782715,0.217285
2,"Sure, I like some indie films. A lot, actually...",0,0.955566,0.044525
3,Blademaster is definitely a memorable entry in...,1,0.069031,0.931152
4,This World War II Popeye cartoon had some very...,1,0.002823,0.997070
...,...,...,...,...
24786,No reason to bother renting this flick. From t...,0,0.978516,0.021454
24787,What gives Anthony Minghella the right to ruin...,0,0.969727,0.030212
24788,My favorite film this year. Great characters a...,1,0.006771,0.993164
24789,Just what is the point of this film? It starts...,0,0.970703,0.029419
