In [1]:
# Data loading
import pandas as pd
from datasets import Dataset, load_from_disk

# Tokenizer Definition
import json
from transformers import PreTrainedTokenizer

# Model training / generation
from transformers import GPT2Config, GPT2LMHeadModel, Trainer, TrainingArguments
from datasets import load_dataset
import torch
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


## Load Dataset

Load the dataset created during the encoding stage.

In [2]:
# Directory written by the encoding stage via `save_to_disk`; the cells below index its
# 'train' and 'test' splits.
dataset_file_path = "encoded_dataset" # Dataset created by text encoding
dataset = load_from_disk(dataset_file_path)

In [3]:
# Inspect the first raw test record: one whitespace-separated string of encoded features.
print(dataset['test'][0])

{'text': '137787 ad_click_list_v001_28621 ad_click_list_v001_21424 ad_click_list_v001_24055 ad_click_list_v001_17305 ad_click_list_v001_31470 ad_click_list_v002_1203 ad_click_list_v002_1172 ad_click_list_v002_1112 ad_click_list_v002_1775 ad_click_list_v002_1041 ad_click_list_v003_162 ad_click_list_v003_240 ad_click_list_v003_343 ad_click_list_v003_246 ad_click_list_v003_312 ad_close_list_v001_24107 ad_close_list_v002_1218 ad_close_list_v003_173 hispace_app_tags_43 u_newsCatInterests_140 u_newsCatInterests_112 u_newsCatInterests_16 u_newsCatInterests_176 u_newsCatInterests_207 u_newsCatDislike_0 u_click_ca2_news_112 u_click_ca2_news_168 u_click_ca2_news_140 u_click_ca2_news_207 u_click_ca2_news_15 i_entities_5b212d9859cc262a2d9f4731b8e1890be315e4d27e4d4602bdc993ec955cdfac i_entities_8e1358ee2230f9112e0464bba2cc119224a6849fd6477d6a316eb358e0bbff14 i_entities_064d7e92c0b22a54f65e6193db3f201ed58258a1f17bed583f1359423fcf7331 i_entities_c81ec0fd7307cf51be43e50261cf60c724d1972d358be6ddb8a1f1c

In [4]:
# Load the JSON file containing the unique token vocabulary.
# JSON text is UTF-8, so decode it as UTF-8 explicitly instead of relying on the
# platform's default encoding (which can mis-decode non-ASCII tokens, e.g. on Windows).
with open('vocab_map.json', 'r', encoding='utf-8') as f:
    token_vocab = json.load(f)

# The tokenizer cell calls token_vocab.extend(...), so the file must hold a JSON array.
if not isinstance(token_vocab, list):
    raise TypeError(
        f"vocab_map.json should contain a JSON list of tokens, got {type(token_vocab).__name__}"
    )

## Defining the Custom Vocabulary and Tokenizer

In [5]:
import re

# Step 1: Define special tokens
BOS_TOKEN = '[BOS]'
EOS_TOKEN = '[EOS]'
PAD_TOKEN = '[PAD]'
UNK_TOKEN = '[UNK]'

# Step 2: Add the special tokens to the vocabulary, but only those not already present.
# This keeps the cell idempotent: re-running it (or loading a vocab file that already has
# them) must not append duplicates, which would leave gaps in the token -> id mapping the
# tokenizer builds from this list and make ids exceed len(tokenizer.get_vocab()).
missing_special_tokens = [
    tok for tok in (BOS_TOKEN, EOS_TOKEN, PAD_TOKEN, UNK_TOKEN) if tok not in token_vocab
]
token_vocab.extend(missing_special_tokens)

# Step 3: Define a regular expression for detecting userIDs (a run of digits).
number_regex = re.compile(r'\d+')

# Step 4: Reinitialize the custom tokenizer to handle digit splitting
class CustomTokenizer(PreTrainedTokenizer):
    """Whitespace tokenizer over a fixed vocabulary that splits bare numbers into digits.

    ``_tokenize`` splits text on whitespace, breaks every word that fully matches
    ``number_regex`` (e.g. the leading user ID) into single-digit tokens, and wraps the
    result in BOS/EOS. Tokens missing from the vocabulary map to the ``[UNK]`` id.

    NOTE(review): BOS/EOS are added inside ``_tokenize`` rather than by
    ``build_inputs_with_special_tokens``, so ``truncation=True`` presumably does not
    reserve room for them and an over-long input loses its EOS. Confirm encoded rows fit
    the chosen ``max_length``.

    Args:
        vocab: Sequence of token strings. A token's id is the index of its first
            occurrence; repeated entries are ignored so ids are contiguous
            (``0 .. len(get_vocab()) - 1``) and ``len(get_vocab())`` is a valid
            embedding size.
        **kwargs: Forwarded to ``PreTrainedTokenizer.__init__``.
    """

    def __init__(self, vocab, **kwargs):
        # Build the vocab before super().__init__, which (depending on the transformers
        # version) may already query get_vocab().
        self.vocab = {}
        for token in vocab:
            # setdefault keeps the first id of a duplicate, so no id is skipped.
            self.vocab.setdefault(token, len(self.vocab))
        super().__init__(**kwargs)
        self.ids_to_tokens = {i: token for token, i in self.vocab.items()}
        self.bos_token = BOS_TOKEN
        self.eos_token = EOS_TOKEN
        self.pad_token = PAD_TOKEN
        self.unk_token = UNK_TOKEN

    @property
    def vocab_size(self):
        """Size of the base vocabulary (abstract in ``PreTrainedTokenizer``)."""
        return len(self.vocab)

    def _tokenize(self, text):
        """Split ``text`` into word tokens (numbers split per digit), wrapped in BOS/EOS."""
        tokens = []
        for word in text.split():
            if number_regex.fullmatch(word):
                # Numbers such as the user ID are modelled digit by digit.
                tokens.extend(list(word))
            else:
                tokens.append(word)
        return [self.bos_token] + tokens + [self.eos_token]

    def _convert_token_to_id(self, token):
        """Map a token to its id, falling back to the ``[UNK]`` id."""
        return self.vocab.get(token, self.vocab[self.unk_token])

    def _convert_id_to_token(self, index):
        """Map an id back to its token, falling back to the ``[UNK]`` token."""
        return self.ids_to_tokens.get(index, self.unk_token)

    def get_vocab(self):
        """Return the token -> id mapping."""
        return self.vocab

    def save_vocabulary(self, save_directory, filename_prefix=None):
        """Write the id-ordered token list as JSON so ``save_pretrained`` can succeed.

        The file is ``[<filename_prefix>-]vocab.json`` inside ``save_directory`` and can be
        reloaded with ``CustomTokenizer(vocab=json.load(f))``.

        Returns:
            tuple[str]: The path of the written file.
        """
        import json
        import os

        prefix = f"{filename_prefix}-" if filename_prefix else ""
        vocab_path = os.path.join(save_directory, f"{prefix}vocab.json")
        ordered_tokens = [token for token, _ in sorted(self.vocab.items(), key=lambda kv: kv[1])]
        with open(vocab_path, "w", encoding="utf-8") as fh:
            json.dump(ordered_tokens, fh, ensure_ascii=False)
        return (vocab_path,)

# Step 5: Initialize the tokenizer
tokenizer = CustomTokenizer(vocab=token_vocab)

# Step 6: Test tokenization with BOS, EOS, and number handling
# Sanity check on the first test record: the leading number should come out as
# single-digit tokens and the whole sequence should be wrapped in [BOS] ... [EOS].
text = dataset['test'][0]['text']
tokens = tokenizer.tokenize(text)
# Tokens absent from the vocabulary map to the [UNK] id here.
token_ids = tokenizer.convert_tokens_to_ids(tokens)

print(f"Tokens with BOS/EOS and number handling: {tokens}")
print(f"Token IDs: {token_ids}")

Tokens with BOS/EOS and number handling: ['[BOS]', '1', '3', '7', '7', '8', '7', 'ad_click_list_v001_28621', 'ad_click_list_v001_21424', 'ad_click_list_v001_24055', 'ad_click_list_v001_17305', 'ad_click_list_v001_31470', 'ad_click_list_v002_1203', 'ad_click_list_v002_1172', 'ad_click_list_v002_1112', 'ad_click_list_v002_1775', 'ad_click_list_v002_1041', 'ad_click_list_v003_162', 'ad_click_list_v003_240', 'ad_click_list_v003_343', 'ad_click_list_v003_246', 'ad_click_list_v003_312', 'ad_close_list_v001_24107', 'ad_close_list_v002_1218', 'ad_close_list_v003_173', 'hispace_app_tags_43', 'u_newsCatInterests_140', 'u_newsCatInterests_112', 'u_newsCatInterests_16', 'u_newsCatInterests_176', 'u_newsCatInterests_207', 'u_newsCatDislike_0', 'u_click_ca2_news_112', 'u_click_ca2_news_168', 'u_click_ca2_news_140', 'u_click_ca2_news_207', 'u_click_ca2_news_15', 'i_entities_5b212d9859cc262a2d9f4731b8e1890be315e4d27e4d4602bdc993ec955cdfac', 'i_entities_8e1358ee2230f9112e0464bba2cc119224a6849fd6477d6a3



In [6]:
# Step 1: Tokenize the dataset and pad/truncate to a given max sequence length
def tokenize_function(examples, max_length, tok=None):
    """Tokenize ``examples["text"]`` to fixed length and build causal-LM labels.

    Works for both ``dataset.map(..., batched=True)`` (``examples["text"]`` is a list of
    strings) and unbatched calls (a single string).

    Args:
        examples: Mapping with a ``"text"`` entry (string or list of strings).
        max_length: Length every sequence is padded/truncated to.
        tok: Tokenizer to use; defaults to the notebook-level ``tokenizer``.

    Returns:
        The tokenizer output with an added ``"labels"`` entry: a fresh copy of
        ``input_ids`` where padding ids are replaced by -100, so the LM loss ignores them.
    """
    tok = tokenizer if tok is None else tok
    tokenized = tok(
        examples["text"],
        truncation=True,        # Truncate sequences longer than max_length
        padding="max_length",   # Pad sequences shorter than max_length
        max_length=max_length,  # Define the max length
        add_special_tokens=True,
    )

    pad_id = tok.pad_token_id

    def _mask_padding(ids):
        # New list (not a shallow copy), so labels never alias input_ids rows.
        return [-100 if token_id == pad_id else token_id for token_id in ids]

    # In autoregressive training the labels are the input_ids; -100 marks positions
    # (here: padding) that the loss must ignore.
    if isinstance(examples["text"], str):
        tokenized["labels"] = _mask_padding(tokenized["input_ids"])
    else:
        tokenized["labels"] = [_mask_padding(row) for row in tokenized["input_ids"]]

    return tokenized

# Step 2: Fixed sequence length for every example.
# Encoded rows have at most 120 whitespace-separated fields; the leading user ID is split
# into digits and BOS/EOS are added, so 128 should leave some headroom (TODO confirm no
# row tokenizes to more than 128 tokens).
max_length = 128

# Step 3: Tokenize every split in batches; fn_kwargs forwards max_length to tokenize_function,
# so all examples are padded/truncated to the same length.
tokenized_datasets = dataset.map(
    tokenize_function,
    batched=True,
    fn_kwargs={"max_length": max_length},
)


Map: 100%|██████████| 32984/32984 [00:23<00:00, 1430.44 examples/s]
Map: 100%|██████████| 8246/8246 [00:05<00:00, 1448.69 examples/s]


In [7]:
# Inspect one tokenized training example: the original text plus the tokenized columns.
tokenized_datasets['train'][0]

{'text': '282939 ad_click_list_v001_10670 ad_click_list_v001_17693 ad_click_list_v001_27955 ad_click_list_v001_35131 ad_click_list_v001_23285 ad_click_list_v002_1220 ad_click_list_v002_1361 ad_click_list_v002_1518 ad_click_list_v002_1961 ad_click_list_v002_1173 ad_click_list_v003_280 ad_click_list_v003_240 ad_click_list_v003_114 ad_click_list_v003_162 ad_click_list_v003_246 ad_close_list_v001_24107 ad_close_list_v002_1218 ad_close_list_v003_173 hispace_app_tags_47 u_newsCatInterests_216 u_newsCatInterests_0 u_newsCatInterests_169 u_newsCatInterests_171 u_newsCatInterests_168 u_newsCatDislike_0 u_click_ca2_news_86 u_click_ca2_news_169 u_click_ca2_news_171 u_click_ca2_news_168 u_click_ca2_news_78 i_entities_1431cd8e8f17b247cbfb2c67be86a0d8bd246262fe7e99e3c97a91a03ffb31bb i_entities_8c2626d7bc49908761f81d427d4d350ddc5a5904c85c9f8265011d3ea8682d42 i_entities_a6dd053044008fc2e0884d4714caeec969cfbf2488d7a1a2fca4770f6c61f3b2 i_entities_c36449314af7aac8ad712c095dde96c62c684f02b6b2a7ee7307bb47e

## Model Training

In [8]:
# Step 1: Define a small GPT-2 architecture, trained from scratch on this vocabulary.
# NOTE: this is a down-sized GPT-2, not knowledge distillation (there is no teacher model).
# Seed torch so the random weight initialization below is reproducible.
torch.manual_seed(42)

config = GPT2Config(
    vocab_size=len(tokenizer.get_vocab()),
    n_embd=256,  # Smaller embedding size than standard GPT-2
    n_layer=6,   # Fewer layers than standard GPT-2
    n_head=4,    # Fewer attention heads
    n_positions=max_length,  # Position embeddings; sequences are padded/truncated to max_length
    # GPT2Config defaults bos/eos to 50256 (the original GPT-2 vocabulary), which is not a
    # token of this tokenizer; point generation at the custom special tokens instead.
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
)

# Initialize a new GPT-2 model with the custom configuration
model = GPT2LMHeadModel(config)

# Step 2: Set up training arguments.
# One epoch over the ~33k training rows at batch size 64 is only about 516 optimizer steps
# on a single device, so eval/save/log intervals are kept well below that; intervals of
# 1000+ steps would never trigger within the run. With load_best_model_at_end=True,
# save_steps must be a multiple of eval_steps.
training_args = TrainingArguments(
    output_dir="checkpoints/gpt2-distilled",
    overwrite_output_dir=True,
    num_train_epochs=1,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    save_steps=100,
    save_total_limit=2,
    logging_dir="./logs",
    logging_steps=100,
    evaluation_strategy="steps",  # renamed `eval_strategy` in newer transformers releases
    eval_steps=100,
    load_best_model_at_end=True,
    seed=42,
)

# Step 3: Define the Trainer (eval_dataset is required for evaluation / best-model selection)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["test"],
)

# Step 4: Train the model
trainer.train()

# Step 5: Save the (best) model
trainer.save_model("checkpoints/gpt2-distilled")



Step,Training Loss,Validation Loss


## Conditional Generation

In [None]:

# Step 1: Function to generate sentences from the first half of each example's real tokens
def complete_sentences(model, tokenizer, tokenized_dataset, max_length):
    """Complete every example from the first half of its non-padding tokens.

    For each example the real (unpadded) length of ``input_ids`` is measured, its first
    half is used as the prompt for ``model.generate``, and the generated ids are decoded
    with special tokens removed.

    Args:
        model: Causal LM exposing ``eval()``, ``device`` and ``generate``.
        tokenizer: Tokenizer providing ``pad_token_id``, ``eos_token_id`` and ``decode``.
        tokenized_dataset: Iterable of dict rows with ``input_ids`` and optionally
            ``attention_mask`` (e.g. the output of ``tokenize_function``).
        max_length: Maximum total length (prompt + generated tokens) for ``generate``.

    Returns:
        list[str]: One decoded completion per example, in dataset order.
    """
    completed_sentences = []

    # Ensure model is in evaluation mode (disables dropout during generation)
    model.eval()

    for example in tqdm(tokenized_dataset):
        input_ids = example['input_ids']

        # Step 2: Measure the real content length. Rows were padded to max_length, so
        # halving len(input_ids) would give every example the same prompt length and,
        # for short examples, put PAD tokens into the prompt.
        attention_mask = example.get('attention_mask')
        if attention_mask is not None:
            # Assumes right padding (mask is 1...1 0...0) — presumably the tokenizer default.
            real_length = int(sum(attention_mask))
        else:
            real_length = len(input_ids)
            while real_length > 0 and input_ids[real_length - 1] == tokenizer.pad_token_id:
                real_length -= 1

        # Take the first half of the real tokens as the prompt (never an empty prompt).
        half_length = max(1, real_length // 2)
        prompt_ids = input_ids[:half_length]

        # Step 3: Use the model to generate the complete sentence
        input_ids_tensor = torch.tensor([prompt_ids]).to(model.device)  # Add batch dimension
        generated_ids = model.generate(
            input_ids=input_ids_tensor,
            max_length=max_length,  # Cap on prompt + generated tokens
            pad_token_id=tokenizer.pad_token_id,  # Ensure proper padding handling
            eos_token_id=tokenizer.eos_token_id  # Stop at EOS token
        )

        # Step 4: Convert generated token IDs back to text (special tokens dropped)
        generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
        completed_sentences.append(generated_text)

    return completed_sentences

# Generate completed sentences as text (one per test example)
completed_sentences = complete_sentences(model, tokenizer, tokenized_datasets['test'], max_length)

# The list has one entry per test example (thousands); show a few instead of dumping all.
print(completed_sentences[:3])

In [25]:
import re

def combine_digits(sentence):
    """Re-join single digits that the tokenizer split apart.

    A run of single-digit words separated by one whitespace character each (e.g.
    ``"1 3 7"``) is collapsed into one number (``"137"``). Digits that belong to a longer
    token such as ``"u_newsCatDislike_0"`` are left untouched.
    """
    def _squash(match):
        # Remove the whitespace between the matched digits.
        return ''.join(match.group(0).split())

    return re.sub(r'(?<=\b)(\d\s)+\d(?=\b)', _squash, sentence)

def process_completed_sentences(completed_sentences):
    """Return a new list with ``combine_digits`` applied to each sentence, order preserved."""
    return list(map(combine_digits, completed_sentences))

# Merge the per-digit tokens back into whole numbers in every generated sentence.
processed_sentences = process_completed_sentences(completed_sentences)
# Spot-check one post-processed sentence.
print(processed_sentences[0])


137787 ad_click_list_v001_28621 ad_click_list_v001_21424 ad_click_list_v001_24055 ad_click_list_v001_17305 ad_click_list_v001_31470 ad_click_list_v002_1203 ad_click_list_v002_1172 ad_click_list_v002_1112 ad_click_list_v002_1775 ad_click_list_v002_1041 ad_click_list_v003_162 ad_click_list_v003_240 ad_click_list_v003_343 ad_click_list_v003_246 ad_click_list_v003_312 ad_close_list_v001_24107 ad_close_list_v002_1218 ad_close_list_v003_173 hispace_app_tags_43 u_newsCatInterests_140 u_newsCatInterests_112 u_newsCatInterests_16 u_newsCatInterests_176 u_newsCatInterests_207 u_newsCatDislike_0 u_click_ca2_news_112 u_click_ca2_news_168 u_click_ca2_news_140 u_click_ca2_news_207 u_click_ca2_news_15 i_entities_5b212d9859cc262a2d9f4731b8e1890be315e4d27e4d4602bdc993ec955cdfac i_entities_8e1358ee2230f9112e0464bba2cc119224a6849fd6477d6a316eb358e0bbff14 i_entities_064d7e92c0b22a54f65e6193db3f201ed58258a1f17bed583f1359423fcf7331 i_entities_c81ec0fd7307cf51be43e50261cf60c724d1972d358be6ddb8a1f1cb191adf98 

In [27]:
import os
def save_sentences_to_file(completed_sentences, file_path):
    """Write each sentence to ``file_path`` (UTF-8), one per line.

    Any existing file is overwritten, and every sentence, including the last one, is
    followed by a newline.
    """
    with open(file_path, 'w', encoding='utf-8') as out_file:
        out_file.writelines(line + '\n' for line in completed_sentences)

syn_dir = "synth_data"
os.makedirs(syn_dir, exist_ok=True)
file_path = 'conditional_generation.txt'  # File name inside syn_dir
output_path = os.path.join(syn_dir, file_path)
save_sentences_to_file(processed_sentences, output_path)

# Report the full path that was written, not just the bare file name.
print(f"Completed sentences have been saved to {output_path}")

Completed sentences have been saved to conditional_generation.txt
