## Load Dataset

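Below is a toy example of loading an encoded CTR dataset: each row is serialized into a space-separated string (the user ID followed by one `column_value` token per remaining column) and the strings are wrapped in a `datasets.Dataset`.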

In [1]:
import pandas as pd
from datasets import Dataset

def row_to_string(row):
    """Serialize one row as ``"<user_id> <col>_<value> <col>_<value> ..."``.

    Args:
        row: A labelled row (e.g. a pandas Series) with a ``'user_id'`` entry
            and an ``index`` listing its column labels.

    Returns:
        The user ID as text, a single space, then one ``column_value`` token
        for every other column in index order, joined by spaces.
    """
    feature_tokens = []
    for column in row.index:
        if column == 'user_id':
            # The user ID leads the string on its own, without a column prefix.
            continue
        feature_tokens.append(f"{column}_{row[column]}")
    return str(row['user_id']) + " " + ' '.join(feature_tokens)

def dataframe_to_string(df):
    """Serialize every row of ``df`` with ``row_to_string``.

    Args:
        df: DataFrame whose rows are passed one at a time to ``row_to_string``.

    Returns:
        The result of ``df.apply(row_to_string, axis=1)``; for the toy data in
        this notebook that is one string per row, indexed like ``df``.
    """
    row_strings = df.apply(row_to_string, axis=1)
    return row_strings

# Toy data: three users, each with two encoded categorical columns
data = dict(
    user_id=[123, 456, 789],
    gender=[1, 0, 1],
    age_group=[0, 1, 0],
)

df = pd.DataFrame(data)
string_representation = dataframe_to_string(df)
print(string_representation.tolist())

# One text column holding the serialized form of each row
data_dict = dict(row_string=string_representation)

dataset = Dataset.from_dict(data_dict)


['123 gender_1 age_group_0', '456 gender_0 age_group_1', '789 gender_1 age_group_0']


In [5]:
# Inspect the mapping handed to Dataset.from_dict (its single value is a pandas Series)
data_dict

{'row_string': 0    123 gender_1 age_group_0
 1    456 gender_0 age_group_1
 2    789 gender_1 age_group_0
 dtype: object}

## Defining Custom Vocabulary

In [2]:
from transformers import PreTrainedTokenizer

class CustomTokenizer(PreTrainedTokenizer):
    """Whitespace tokenizer over a fixed, caller-supplied vocabulary.

    Text is split on whitespace. Parts made only of digits (such as a user ID)
    are broken into one token per digit; every other part is kept as a single
    token. A token's ID is its position in ``vocab``.
    """

    def __init__(self, vocab, **kwargs):
        """Build the token <-> ID lookup tables, then initialise the base class.

        Args:
            vocab: Iterable of token strings; position ``i`` becomes ID ``i``.
            **kwargs: Forwarded unchanged to ``PreTrainedTokenizer.__init__``.
        """
        # Materialise once so a one-shot iterable can be walked for both tables.
        self.vocab = list(vocab)
        self.ids_to_tokens = dict(enumerate(self.vocab))
        self.tokens_to_ids = {token: i for i, token in enumerate(self.vocab)}

        # The tables are built before the base constructor runs.
        # NOTE(review): presumably the base __init__ may consult get_vocab() —
        # confirm for the installed transformers version before reordering.
        super().__init__(**kwargs)

    @property
    def vocab_size(self):
        """Number of IDs in the base vocabulary (one per position in ``vocab``)."""
        return len(self.vocab)

    def _tokenize(self, text):
        """Split ``text`` on whitespace, exploding all-digit parts into digits.

        Args:
            text: The string to tokenize.

        Returns:
            A list of token strings.
        """
        tokens = []
        for part in text.split():
            if part.isdigit():
                # Iterating a string yields its characters: one token per digit.
                tokens.extend(part)
            else:
                tokens.append(part)
        return tokens

    def _require_id(self, token):
        """Return the ID for one token, raising ``KeyError`` if it has none."""
        # Delegating a single string to the base class lets tokens registered
        # through the base tokenizer (e.g. special tokens) resolve as well.
        token_id = super().convert_tokens_to_ids(token)
        if token_id is None:
            raise KeyError(token)
        return token_id

    def convert_tokens_to_ids(self, tokens):
        """Map a token, or an iterable of tokens, to vocabulary IDs.

        Args:
            tokens: A single token string, an iterable of token strings, or
                ``None``.

        Returns:
            An int for a single string, a list of ints for an iterable, and
            ``None`` for ``None``.

        Raises:
            KeyError: If a token has no ID, i.e. it is not in the vocabulary
                and there is no ``'[UNK]'`` entry to fall back on.
        """
        if tokens is None:
            return None
        if isinstance(tokens, str):
            # A bare string is ONE token; iterating it would look up characters.
            return self._require_id(tokens)
        return [self._require_id(token) for token in tokens]

    def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
        """Map an ID, or a sequence of IDs, back to token strings.

        Args:
            ids: A single int or a sequence of ints.
            skip_special_tokens: Forwarded to the base implementation so that
                callers using the base-class signature (keyword included) work.

        Returns:
            A string for a single int, otherwise a list of strings. IDs with
            no vocabulary entry resolve through ``_convert_id_to_token`` and
            therefore come back as ``'[UNK]'``.
        """
        return super().convert_ids_to_tokens(ids, skip_special_tokens=skip_special_tokens)

    def _convert_token_to_id(self, token):
        """Look up one token; fall back to the ``'[UNK]'`` ID, or ``None`` if absent."""
        return self.tokens_to_ids.get(token, self.tokens_to_ids.get('[UNK]'))

    def _convert_id_to_token(self, index):
        """Look up one ID; unknown IDs map to the literal ``'[UNK]'``."""
        return self.ids_to_tokens.get(index, '[UNK]')

    def get_vocab(self):
        """Return the token -> ID mapping built from ``vocab``."""
        return self.tokens_to_ids

    def save_vocabulary(self, save_directory, filename_prefix=None):
        """Write the vocabulary list to ``<save_directory>/[<prefix>-]vocab.json``.

        The tokens are stored as a JSON list in ID order, so passing the loaded
        list back to the constructor rebuilds the same tables.

        Args:
            save_directory: Directory to write into (created if missing).
            filename_prefix: Optional prefix prepended to the file name.

        Returns:
            A one-element tuple holding the path of the written file.
        """
        import json
        import os

        os.makedirs(save_directory, exist_ok=True)
        file_name = f"{filename_prefix}-vocab.json" if filename_prefix else "vocab.json"
        vocab_path = os.path.join(save_directory, file_name)
        with open(vocab_path, "w", encoding="utf-8") as handle:
            json.dump(self.vocab, handle, ensure_ascii=False, indent=2)
        return (vocab_path,)

# Toy data vocabulary: the ten digit tokens first, then the categorical levels.
# Replace with actual CTR data levels.
vocab = list("0123456789") + [
    "gender_0", "gender_1", "gender_2",
    "age_group_0", "age_group_1",
]

tokenizer = CustomTokenizer(vocab=vocab)

# Tokenize the first serialized row and look up the ID of each token
row_string = string_representation[0]
tokens = tokenizer.tokenize(row_string)
token_ids = tokenizer.convert_tokens_to_ids(tokens)

print(f"Tokens: {tokens}")
print(f"Token IDs: {token_ids}")


Tokens: ['1', '2', '3', 'gender_1', 'age_group_0']
Token IDs: [1, 2, 3, 11, 13]


In [3]:
# Define tokenizer based on the vocab
# (rebinds `tokenizer`, which tokenize_example below reads as a module-level name)
tokenizer = CustomTokenizer(vocab=vocab)

# Tokenizing the dataset
def tokenize_example(example):
    """Turn one dataset record into causal-LM training features.

    Args:
        example: Mapping with a ``'row_string'`` entry holding a serialized row.

    Returns:
        A dict whose ``input_ids`` and ``labels`` both hold the token IDs
        produced by the module-level ``tokenizer`` for that row.
    """
    row_tokens = tokenizer.tokenize(example['row_string'])
    ids = tokenizer.convert_tokens_to_ids(row_tokens)
    # Causal LM training: the labels are the input IDs themselves
    return dict(input_ids=ids, labels=ids)

# Apply the tokenizer to the dataset
# (the keys returned by tokenize_example show up as extra columns next to
#  'row_string' in the printed example)
tokenized_dataset = dataset.map(tokenize_example)

# Show an example from the tokenized dataset
print(tokenized_dataset[0])


Map:   0%|          | 0/3 [00:00<?, ? examples/s]

{'row_string': '123 gender_1 age_group_0', 'input_ids': [1, 2, 3, 11, 13], 'labels': [1, 2, 3, 11, 13]}


## Example Model Training

In [4]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer, Trainer, TrainingArguments
from datasets import Dataset
import torch

In [5]:
# Load the model
def load_model(from_scratch=True):
    """Create the GPT-2 language model sized for the module-level ``vocab``.

    Args:
        from_scratch: When true, build a fresh ``GPT2LMHeadModel`` from a
            config whose ``vocab_size`` is ``len(vocab)``. Otherwise load the
            ``'gpt2'`` checkpoint and resize its token embeddings to
            ``len(vocab)``.

    Returns:
        The ``GPT2LMHeadModel`` instance.
    """
    vocab_size = len(vocab)

    if not from_scratch:
        pretrained = GPT2LMHeadModel.from_pretrained('gpt2')
        # The checkpoint's embedding table must match the custom vocabulary size
        pretrained.resize_token_embeddings(vocab_size)
        return pretrained

    config = GPT2LMHeadModel.config_class(vocab_size=vocab_size)
    return GPT2LMHeadModel(config=config)

model = load_model(from_scratch=True)  # Set to False to use pre-trained model
# Checkpoints written during training and the final save below share this directory
save_dir = "./results_custom_tokenizer"

# Define training arguments
# With the 3 toy examples and a batch size of 2 this gives 2 steps per epoch,
# i.e. 20 steps over 10 epochs on a single device (the run output reports
# global_step=20), so save_steps/logging_steps of 10 fire twice.
training_args = TrainingArguments(
    output_dir=save_dir,
    overwrite_output_dir=True,
    num_train_epochs=10,
    per_device_train_batch_size=2,
    save_steps=10,
    save_total_limit=2,
    logging_dir='./logs',
    logging_steps=10,
)

# Initialize Trainer
# NOTE(review): no data collator is passed, so batching presumably relies on
# every example having the same number of IDs (true for the toy rows, 5 each)
# — confirm before training on rows of varying length.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
)

# Train the model
# (kept as the cell's last expression so the returned TrainOutput is displayed)
trainer.train()

Step,Training Loss
10,1.6768
20,0.1338


TrainOutput(global_step=20, training_loss=0.9052952706813813, metrics={'train_runtime': 3.8237, 'train_samples_per_second': 7.846, 'train_steps_per_second': 5.23, 'total_flos': 76550400000.0, 'train_loss': 0.9052952706813813, 'epoch': 10.0})

In [None]:
# Persist the trained model and the tokenizer side by side in save_dir.
# NOTE(review): tokenizer.save_pretrained presumably requires CustomTokenizer to
# implement save_vocabulary — confirm this cell runs (it shows no execution count).
model.save_pretrained(save_dir)
tokenizer.save_pretrained(save_dir)

## Conditional Generation

In [8]:
# Toy setup: the "test" set is the very same object as the training set,
# so the generations below are checked against rows the model was trained on.
tokenized_test_dataset = tokenized_dataset

In [16]:
# Function to generate text from the model
def generate_text_from_prompt(model, tokenizer, input_ids, max_length=5, prompt_length=3):
    """Complete a row from its leading token IDs and return it as text.

    Args:
        model: Object exposing ``device``, ``eval()`` and ``generate(...)``.
        tokenizer: Object exposing a ``tokens_to_ids`` mapping and
            ``convert_ids_to_tokens``.
        input_ids: Token IDs of a full row; only the first ``prompt_length``
            are fed to the model.
        max_length: Forwarded to ``model.generate`` as ``max_length``.
            NOTE(review): presumably the total length including the prompt —
            confirm.
        prompt_length: Number of leading IDs used as the prompt.

    Returns:
        The tokens of the first generated sequence joined by single spaces.
    """
    # A batch holding just the prompt, placed on the model's device
    prompt_batch = torch.tensor([input_ids[:prompt_length]], dtype=torch.long).to(model.device)

    model.eval()
    with torch.no_grad():
        sequences = model.generate(
            input_ids=prompt_batch,
            max_length=max_length,
            num_return_sequences=1,
            # Both lookups yield None when the vocabulary has no such entry
            pad_token_id=tokenizer.tokens_to_ids.get('[PAD]', None),
            eos_token_id=tokenizer.tokens_to_ids.get('[EOS]', None),
        )

    # Only the first returned sequence is decoded
    first_sequence_ids = sequences[0].tolist()
    decoded_tokens = tokenizer.convert_ids_to_tokens(first_sequence_ids)
    return " ".join(decoded_tokens)

In [17]:
# Regenerate every row of the test set from its prompt and compare by eye
for i, example in enumerate(tokenized_test_dataset):
    print(f"Original row string: {example['row_string']}")
    
    # Prompt with the first 3 token IDs (the function's default prompt_length;
    # for the toy rows these are the user-ID digits), default max_length=5
    generated_text = generate_text_from_prompt(model, tokenizer, example['input_ids'])
    
    # Print the generated text
    print(f"Generated completion: {generated_text}\n")

Original row string: 123 gender_1 age_group_0
Generated completion: 1 2 3 gender_1 age_group_0

Original row string: 456 gender_0 age_group_1
Generated completion: 4 5 6 gender_0 age_group_1

Original row string: 789 gender_1 age_group_0
Generated completion: 7 8 9 gender_1 age_group_0

