In [None]:
from huggingface_hub import login

login(token = "YOUR_HUGGINGFACE_TOKEN")

In [13]:
from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaForCausalLM

original_tokenizer = AutoTokenizer.from_pretrained("nickypro/tinyllama-110M")
model = LlamaForCausalLM.from_pretrained("nickypro/tinyllama-110M")

original_tokenizer.padding_side = "right"

In [14]:
import torch
from transformers import pipeline
import torch.nn as nn

original_vocab_size, embedding_dim = model.model.embed_tokens.weight.shape

extra_embedding_1 = nn.Embedding(original_vocab_size, embedding_dim)
extra_embedding_2 = nn.Embedding(original_vocab_size, embedding_dim)

nn.init.xavier_uniform_(extra_embedding_1.weight)
nn.init.xavier_uniform_(extra_embedding_2.weight)

extra_embedding_1.weight.requires_grad = True
extra_embedding_2.weight.requires_grad = True

# Add the new embeddings as attributes
model.model.extra_embedding_1 = extra_embedding_1
model.model.extra_embedding_2 = extra_embedding_2

# Copy weights
model.model.extra_embedding_1.weight.data.copy_(model.model.embed_tokens.weight.data)
model.model.extra_embedding_2.weight.data.copy_(model.model.embed_tokens.weight.data)


tensor([[-0.0224, -0.0145,  0.0013,  ...,  0.0682,  0.0056, -0.0279],
        [-0.0008, -0.0003, -0.0180,  ...,  0.0243, -0.0092,  0.0170],
        [-0.0224, -0.0145,  0.0013,  ...,  0.0681,  0.0056, -0.0279],
        ...,
        [-0.0224, -0.0145,  0.0013,  ...,  0.0682,  0.0056, -0.0279],
        [-0.0224, -0.0145,  0.0013,  ...,  0.0681,  0.0056, -0.0279],
        [-0.0224, -0.0145,  0.0013,  ...,  0.0682,  0.0056, -0.0279]])

In [15]:
from typing import List, Optional, Union
from cachetools import Cache
import types

def modified_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    num_logits_to_keep: int = 0,
    **kwargs
):
    if input_ids is None and inputs_embeds is None:
        raise ValueError("You have to specify either input_ids or inputs_embeds")

    batch_size = input_ids.shape[0]
    combined_embeds = []

    for batch_idx in range(batch_size):
        str_input_ids = " ".join([str(i) for i in input_ids[batch_idx].tolist()])
        input_ids_parts = str_input_ids.split(" 200000 ")
        # print(len(input_ids_parts), input_ids_parts)
        bpe = torch.tensor([list(map(int, input_ids_parts[0].split(" ")))], 
                           device=input_ids.device)
        wordpiece = torch.tensor([list(map(int, input_ids_parts[1].split(" ")))], 
                               device=input_ids.device)
        unigram = torch.tensor([list(map(int, input_ids_parts[2].split(" ")))], 
                                     device=input_ids.device)
        
        bpe_embedding = self.model.embed_tokens(bpe)
        wordpiece_embedding = self.model.extra_embedding_1(wordpiece)
        unigram_embedding = self.model.extra_embedding_2(unigram)
        
        # print(f"Shapes: BPE {bpe_embedding.shape}, Unigram {unigram_embedding.shape}, SentencePiece {sentencepiece_embedding.shape}")
        
        min_length = min(bpe_embedding.shape[1], wordpiece_embedding.shape[1], unigram_embedding.shape[1])
        bpe_embedding = bpe_embedding[:, :min_length, :]
        wordpiece_embedding = wordpiece_embedding[:, :min_length, :]
        unigram_embedding = unigram_embedding[:, :min_length, :]
        
        batch_embeds = bpe_embedding + wordpiece_embedding + unigram_embedding
        combined_embeds.append(batch_embeds)

    inputs_embeds = torch.cat(combined_embeds, dim=0)
    # print(f"Final inputs_embeds shape: {inputs_embeds.shape}")
    
    if attention_mask is not None:
        attention_mask = attention_mask[:, :inputs_embeds.shape[1]]
    
    return self.original_forward(
        input_ids=None,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
        labels=labels,
        num_logits_to_keep=num_logits_to_keep,
        **kwargs
    )
    
model.original_forward = model.forward
model.forward = types.MethodType(modified_forward, model)


In [16]:
# Freeze everething, except the new embeddings
for param in model.parameters():
    param.requires_grad = False

for param in model.model.extra_embedding_1.parameters():
    param.requires_grad = True

for param in model.model.extra_embedding_2.parameters():
    param.requires_grad = True


In [6]:
#check if the embeddings are trainable
for name, param in model.named_parameters():
    print(name, param.requires_grad)


model.embed_tokens.weight False
model.layers.0.self_attn.q_proj.weight False
model.layers.0.self_attn.k_proj.weight False
model.layers.0.self_attn.v_proj.weight False
model.layers.0.self_attn.o_proj.weight False
model.layers.0.mlp.gate_proj.weight False
model.layers.0.mlp.up_proj.weight False
model.layers.0.mlp.down_proj.weight False
model.layers.0.input_layernorm.weight False
model.layers.0.post_attention_layernorm.weight False
model.layers.1.self_attn.q_proj.weight False
model.layers.1.self_attn.k_proj.weight False
model.layers.1.self_attn.v_proj.weight False
model.layers.1.self_attn.o_proj.weight False
model.layers.1.mlp.gate_proj.weight False
model.layers.1.mlp.up_proj.weight False
model.layers.1.mlp.down_proj.weight False
model.layers.1.input_layernorm.weight False
model.layers.1.post_attention_layernorm.weight False
model.layers.2.self_attn.q_proj.weight False
model.layers.2.self_attn.k_proj.weight False
model.layers.2.self_attn.v_proj.weight False
model.layers.2.self_attn.o_proj

In [7]:
from datasets import load_dataset

ds = load_dataset("mlabonne/FineTome-Alpaca-100k", split="train")


In [8]:
def tokenize(examples, tokenizer):
    texts = [f"### Instruction: {instruction}\n### Response: {output}" 
             for instruction, output in zip(examples['instruction'], examples['output'])]
    
    tokenized = tokenizer(
        texts,
        truncation=True,
        max_length=1024,
        padding="max_length",
        return_tensors=None
    )
        
    # Add labels for causal language modeling
    tokenized["labels"] = [ids.copy() for ids in tokenized["input_ids"]]
    return tokenized

In [9]:
from transformers import PreTrainedTokenizerFast
from transformers import AutoTokenizer

wordpiece_tokenizer = AutoTokenizer.from_pretrained("trained_tokenizers/wordpiece")
unigram_tokenizer = AutoTokenizer.from_pretrained("trained_tokenizers/unigram")

In [11]:
text = "Can you tell me a joke?"

original_tokenizer.pad_token_id = original_tokenizer.eos_token_id
unigram_tokenizer.pad_token_id = unigram_tokenizer.eos_token_id
wordpiece_tokenizer.pad_token_id = wordpiece_tokenizer.eos_token_id

original_tokens = original_tokenizer(text, padding='max_length', max_length=10)
wordpiece_tokens = wordpiece_tokenizer(text, padding='max_length', max_length=10)
unigram_tokens = unigram_tokenizer(text, padding='max_length', max_length=10)

combined_text = original_tokens["input_ids"] + [200000] + wordpiece_tokens["input_ids"] + [200000] + unigram_tokens["input_ids"]
combined_attention_mask = original_tokens['attention_mask'] + [200000] + wordpiece_tokens['attention_mask'] + [200000] + unigram_tokens['attention_mask']

In [12]:
model

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 768)
    (layers): ModuleList(
      (0-11): 12 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=768, out_features=768, bias=False)
          (k_proj): Linear(in_features=768, out_features=768, bias=False)
          (v_proj): Linear(in_features=768, out_features=768, bias=False)
          (o_proj): Linear(in_features=768, out_features=768, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=768, out_features=2048, bias=False)
          (up_proj): Linear(in_features=768, out_features=2048, bias=False)
          (down_proj): Linear(in_features=2048, out_features=768, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((768,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((768,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((768,), eps=1e-05)
    (rotary_emb): LlamaRotaryEm

In [None]:
import datasets
import gc


original_tokenizer.pad_token = original_tokenizer.eos_token
wordpiece_tokenizer.pad_token = wordpiece_tokenizer.eos_token
unigram_tokenizer.pad_token = unigram_tokenizer.eos_token

orig_tokenized_text = ds.map(lambda examples: tokenize(examples, original_tokenizer), batched=True, remove_columns=['instruction',"source","score",'output'])
orig_tokenized_text = orig_tokenized_text.rename_column("input_ids", "input_ids1")
orig_tokenized_text = orig_tokenized_text.rename_column("attention_mask", "attention_mask1")
orig_tokenized_text = orig_tokenized_text.rename_column("labels", "labels1")

wordpiece_tokenized_text = ds.map(lambda examples: tokenize(examples, wordpiece_tokenizer), batched=True, remove_columns=['instruction',"source","score",'output'])
wordpiece_tokenized_text = wordpiece_tokenized_text.rename_column("input_ids", "input_ids2")
wordpiece_tokenized_text = wordpiece_tokenized_text.rename_column("attention_mask", "attention_mask2")
wordpiece_tokenized_text = wordpiece_tokenized_text.rename_column("labels", "labels2")

unigram_tokenized_text = ds.map(lambda examples: tokenize(examples, unigram_tokenizer), batched=True, remove_columns=['instruction',"source","score",'output'])
unigram_tokenized_text = unigram_tokenized_text.rename_column("input_ids", "input_ids3")
unigram_tokenized_text = unigram_tokenized_text.rename_column("attention_mask", "attention_mask3")
unigram_tokenized_text = unigram_tokenized_text.rename_column("labels", "labels3")

# Unify the tokenized datasets
wordpiece_tokenized_text = wordpiece_tokenized_text.remove_columns(['token_type_ids'])
unigram_tokenized_text = unigram_tokenized_text.remove_columns(['token_type_ids'])

# Unify the tokenized datasets
unified_tokenized_text = datasets.concatenate_datasets([orig_tokenized_text, wordpiece_tokenized_text, unigram_tokenized_text], axis=1)


del orig_tokenized_text
del wordpiece_tokenized_text
del unigram_tokenized_text

gc.collect()

30

In [None]:
def concat_rows(row, max_length=1024):
    sep = [200000]

    combined_input_ids = row['input_ids1'] + sep + row['input_ids2'] + sep + row['input_ids3']
    combined_attention_mask = row['attention_mask1'] + sep + row['attention_mask2'] + sep + row['attention_mask3']
    combined_labels = row['labels1'] 

    return {
        'input_ids': combined_input_ids,
        'attention_mask': combined_attention_mask,
        'labels': combined_labels
    }

unified_tokenized_text = unified_tokenized_text.map(
    lambda row: concat_rows(row, max_length=1024),
    remove_columns=['input_ids1', 'input_ids2', 'input_ids3', 'attention_mask1', 'attention_mask2', 'attention_mask3', 'labels1', 'labels2', 'labels3']
)


In [13]:
print(len(unified_tokenized_text[0]['input_ids']))
print(len(unified_tokenized_text[0]['labels']))
print(len(unified_tokenized_text[0]['attention_mask']))

3074
1024
3074


In [17]:
from torch.utils.data import DataLoader
from transformers import TrainingArguments, Trainer, DataCollatorForLanguageModeling

model.train()
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)

optimizer = torch.optim.AdamW([
        {"params": model.model.extra_embedding_1.parameters(), "lr": 2e-5},
        {"params": model.model.extra_embedding_2.parameters(), "lr": 2e-5}
])

def multi_tokenizer_data_collator(features):
    """
    Custom data collator for handling input from multiple tokenizers
    with separator tokens (200000) preserved.
    """
    # Input IDs with batch dimension
    input_ids = torch.stack([torch.tensor(f["input_ids"]) for f in features])
    attention_mask = torch.stack([torch.tensor(f["attention_mask"]) for f in features])
    labels = torch.stack([torch.tensor(f["labels"]) for f in features])

    # print(input_ids.shape, labels.shape, attention_mask.shape)
    
    return {
        "input_ids": input_ids,  
        "attention_mask": attention_mask,
        "labels": labels
    }

# Create the data collator
data_collator = multi_tokenizer_data_collator

# Define training arguments
training_args = TrainingArguments(
    output_dir="./trained_embeddings",
    per_device_train_batch_size=2,  
    num_train_epochs=3,
    logging_dir="./logs",
    save_strategy="steps",
    save_steps=5000,
    learning_rate=2e-5,
    remove_unused_columns=False,  
    fp16=True,  
    optim="adamw_torch",
    logging_strategy="steps",
    logging_steps=1000,  # Log every 100 steps
    disable_tqdm=False,  # Ensure tqdm is enabled
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=unified_tokenized_text,
    data_collator=data_collator,
    optimizers=(optimizer, None)  
)


trainer.train()


  return self.original_forward(


Step,Training Loss
1000,8.9393
2000,8.5757
3000,8.4654
4000,8.3877
5000,8.3617
6000,8.3143
7000,8.3172
8000,8.2761
9000,8.2371
10000,8.2109


  return self.original_forward(
  return self.original_forward(
  return self.original_forward(
  return self.original_forward(
  return self.original_forward(
  return self.original_forward(
  return self.original_forward(
  return self.original_forward(
  return self.original_forward(
  return self.original_forward(
  return self.original_forward(
  return self.original_forward(
  return self.original_forward(
  return self.original_forward(
  return self.original_forward(
  return self.original_forward(
  return self.original_forward(
  return self.original_forward(
  return self.original_forward(
  return self.original_forward(
  return self.original_forward(
  return self.original_forward(
  return self.original_forward(
  return self.original_forward(
  return self.original_forward(
  return self.original_forward(
  return self.original_forward(
  return self.original_forward(
  return self.original_forward(


TrainOutput(global_step=150000, training_loss=8.084710865885416, metrics={'train_runtime': 19986.5399, 'train_samples_per_second': 15.01, 'train_steps_per_second': 7.505, 'total_flos': 4.700666760192e+17, 'train_loss': 8.084710865885416, 'epoch': 3.0})

In [18]:
trainer.model.save_pretrained("saved_model")