In [3]:
from huggingface_hub import login

login(token = "hf_") 


In [4]:
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("/kaggle/input/model-with-trained-embeddings/saved_model")

Some weights of the model checkpoint at /kaggle/input/model-with-trained-embeddings/saved_model were not used when initializing LlamaForCausalLM: ['model.extra_embedding_1.weight', 'model.extra_embedding_2.weight']
- This IS expected if you are initializing LlamaForCausalLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing LlamaForCausalLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [5]:
from transformers import LlamaForCausalLM
import torch.nn as nn

class LlamaWithExtraEmbeddings(LlamaForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        
        original_vocab_size, embedding_dim = self.model.embed_tokens.weight.shape
        self.model.extra_embedding_1 = nn.Embedding(original_vocab_size, embedding_dim)
        self.model.extra_embedding_2 = nn.Embedding(original_vocab_size, embedding_dim)


In [6]:
model = LlamaWithExtraEmbeddings.from_pretrained("/kaggle/input/model-with-trained-embeddings/saved_model")

In [7]:
model

LlamaWithExtraEmbeddings(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 768)
    (layers): ModuleList(
      (0-11): 12 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=768, out_features=768, bias=False)
          (k_proj): Linear(in_features=768, out_features=768, bias=False)
          (v_proj): Linear(in_features=768, out_features=768, bias=False)
          (o_proj): Linear(in_features=768, out_features=768, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=768, out_features=2048, bias=False)
          (up_proj): Linear(in_features=768, out_features=2048, bias=False)
          (down_proj): Linear(in_features=2048, out_features=768, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((768,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((768,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((768,), eps=1e-05)
    (rotary_emb): Llama

## Mean of embeddings

In [8]:
from typing import List, Optional, Union
from cachetools import Cache
import types
import torch

def modified_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    num_logits_to_keep: int = 0,
    **kwargs
):
    if input_ids is None and inputs_embeds is None:
        raise ValueError("You have to specify either input_ids or inputs_embeds")

    batch_size = input_ids.shape[0]
    combined_embeds = []

    for batch_idx in range(batch_size):
        str_input_ids = " ".join([str(i) for i in input_ids[batch_idx].tolist()])
        input_ids_parts = str_input_ids.split(" 200000 ")

        embed_list = []
        count = 0

        if self.embedding_config.get("use_bpe", False):
            bpe_ids = torch.tensor([list(map(int, input_ids_parts[0].split(" ")))], device=input_ids.device)
            bpe_emb = self.model.embed_tokens(bpe_ids)
            embed_list.append(bpe_emb)
            count += 1

        if self.embedding_config.get("use_wordpiece", False):
            wp_ids = torch.tensor([list(map(int, input_ids_parts[1].split(" ")))], device=input_ids.device)
            wp_emb = self.model.extra_embedding_1(wp_ids)
            embed_list.append(wp_emb)
            count += 1

        if self.embedding_config.get("use_unigram", False):
            uni_ids = torch.tensor([list(map(int, input_ids_parts[2].split(" ")))], device=input_ids.device)
            uni_emb = self.model.extra_embedding_2(uni_ids)
            embed_list.append(uni_emb)
            count += 1

        min_len = min(e.shape[1] for e in embed_list)
        embed_list = [e[:, :min_len, :] for e in embed_list]
        averaged = sum(embed_list) / count
        combined_embeds.append(averaged)

    inputs_embeds = torch.cat(combined_embeds, dim=0)

    if attention_mask is not None:
        attention_mask = attention_mask[:, :inputs_embeds.shape[1]]

    return self.original_forward(
        input_ids=None,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
        labels=labels,
        num_logits_to_keep=num_logits_to_keep,
        **kwargs
    )

    
model.original_forward = model.forward
model.forward = types.MethodType(modified_forward, model)

In [9]:
model.embedding_config = {
    "use_bpe": True,         # original
    "use_wordpiece": True,   # wordpiece_tokenizer
    "use_unigram": True      # unigram_tokenizer
}


In [10]:
from datasets import load_dataset

ds = load_dataset("mlabonne/FineTome-Alpaca-100k", split="train")

README.md:   0%|          | 0.00/408 [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/89.7M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/100000 [00:00<?, ? examples/s]

In [11]:
def tokenize(examples, tokenizer):
    texts = [f"### Instruction: {instruction}\n### Response: {output}" 
             for instruction, output in zip(examples['instruction'], examples['output'])]
    
    tokenized = tokenizer(
        texts,
        truncation=True,
        max_length=1024,
        padding="max_length",
        return_tensors=None
    )
        
    # Add labels for causal language modeling
    tokenized["labels"] = [ids.copy() for ids in tokenized["input_ids"]]
    return tokenized

In [12]:
from transformers import PreTrainedTokenizerFast
from transformers import AutoTokenizer
from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaForCausalLM

original_tokenizer = AutoTokenizer.from_pretrained("nickypro/tinyllama-110M")
wordpiece_tokenizer = AutoTokenizer.from_pretrained("/kaggle/input/nlp-tokenizers/tokenizers/wordpiece_tokenizer")
unigram_tokenizer = AutoTokenizer.from_pretrained("/kaggle/input/nlp-tokenizers/tokenizers/unigram_tokenizer")

tokenizer_config.json:   0%|          | 0.00/686 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.84M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/411 [00:00<?, ?B/s]

You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565 - if you loaded a llama tokenizer from a GGUF file you can ignore this message.


In [13]:
import datasets
import gc


original_tokenizer.pad_token = original_tokenizer.eos_token
wordpiece_tokenizer.pad_token = wordpiece_tokenizer.eos_token
unigram_tokenizer.pad_token = unigram_tokenizer.eos_token

orig_tokenized_text = ds.map(lambda examples: tokenize(examples, original_tokenizer), batched=True, remove_columns=['instruction',"source","score",'output'])
orig_tokenized_text = orig_tokenized_text.rename_column("input_ids", "input_ids1")
orig_tokenized_text = orig_tokenized_text.rename_column("attention_mask", "attention_mask1")
orig_tokenized_text = orig_tokenized_text.rename_column("labels", "labels1")

wordpiece_tokenized_text = ds.map(lambda examples: tokenize(examples, wordpiece_tokenizer), batched=True, remove_columns=['instruction',"source","score",'output'])
wordpiece_tokenized_text = wordpiece_tokenized_text.rename_column("input_ids", "input_ids2")
wordpiece_tokenized_text = wordpiece_tokenized_text.rename_column("attention_mask", "attention_mask2")
wordpiece_tokenized_text = wordpiece_tokenized_text.rename_column("labels", "labels2")

unigram_tokenized_text = ds.map(lambda examples: tokenize(examples, unigram_tokenizer), batched=True, remove_columns=['instruction',"source","score",'output'])
unigram_tokenized_text = unigram_tokenized_text.rename_column("input_ids", "input_ids3")
unigram_tokenized_text = unigram_tokenized_text.rename_column("attention_mask", "attention_mask3")
unigram_tokenized_text = unigram_tokenized_text.rename_column("labels", "labels3")

# Unify the tokenized datasets
# Remove the 'token_type_ids' column from each dataset to avoid duplication
# orig_tokenized_text = orig_tokenized_text.remove_columns(['token_type_ids'])
wordpiece_tokenized_text = wordpiece_tokenized_text.remove_columns(['token_type_ids'])
unigram_tokenized_text = unigram_tokenized_text.remove_columns(['token_type_ids'])

# Unify the tokenized datasets
unified_tokenized_text = datasets.concatenate_datasets([orig_tokenized_text, wordpiece_tokenized_text, unigram_tokenized_text], axis=1)


del orig_tokenized_text
del wordpiece_tokenized_text
del unigram_tokenized_text
# del unigram_tokenized_text

gc.collect()

Map:   0%|          | 0/100000 [00:00<?, ? examples/s]

Map:   0%|          | 0/100000 [00:00<?, ? examples/s]

Map:   0%|          | 0/100000 [00:00<?, ? examples/s]

32

In [14]:
def concat_rows(row, max_length=1024):
    sep = [200000]

    combined_input_ids = row['input_ids1'] + sep + row['input_ids2'] + sep + row['input_ids3']
    combined_attention_mask = row['attention_mask1'] + sep + row['attention_mask2'] + sep + row['attention_mask3']
    combined_labels = row['labels1'] 

    return {
        'input_ids': combined_input_ids,
        'attention_mask': combined_attention_mask,
        'labels': combined_labels
    }


unified_tokenized_text = unified_tokenized_text.map(
    lambda row: concat_rows(row, max_length=1024),
    remove_columns=['input_ids1', 'input_ids2', 'input_ids3', 'attention_mask1', 'attention_mask2', 'attention_mask3', 'labels1', 'labels2', 'labels3']
)

Map:   0%|          | 0/100000 [00:00<?, ? examples/s]

In [15]:

from torch.utils.data import DataLoader
from transformers import TrainingArguments, Trainer, DataCollatorForLanguageModeling


def multi_tokenizer_data_collator(features):
    """
    Custom data collator for handling input from multiple tokenizers
    with separator tokens (200000) preserved.
    """
    # Input IDs with batch dimension
    input_ids = torch.stack([torch.tensor(f["input_ids"]) for f in features])
    attention_mask = torch.stack([torch.tensor(f["attention_mask"]) for f in features])
    labels = torch.stack([torch.tensor(f["labels"]) for f in features])

    # print(input_ids.shape, labels.shape, attention_mask.shape)
    
    return {
        "input_ids": input_ids,  
        "attention_mask": attention_mask,
        "labels": labels
    }

# Create the data collator
data_collator = multi_tokenizer_data_collator


In [16]:
from torch.utils.data import DataLoader

loader = DataLoader(
    unified_tokenized_text,
    batch_size=1,
    collate_fn=multi_tokenizer_data_collator
)

batch = next(iter(loader))


In [17]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
batch = {k: v.to(device) for k, v in batch.items()}


In [18]:
model.embedding_config = {
    "use_bpe": True,         # original
    "use_wordpiece": False,   # wordpiece_tokenizer
    "use_unigram": True      # unigram_tokenizer
}


In [19]:
model.eval()
with torch.no_grad():
    output = model(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
        labels=batch["labels"]
    )

print("Loss:", output.loss.item())
print("Logits shape:", output.logits.shape)


`loss_type=None` was set in the config but it is unrecognised.Using the default loss: `ForCausalLMLoss`.


Loss: 7.5166850090026855
Logits shape: torch.Size([1, 1024, 32000])


## Concatenation of embeddings

In [20]:
from transformers import LlamaForCausalLM
import torch.nn as nn

class LlamaWithExtraEmbeddings(LlamaForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        self.embedding_config = {
            "use_bpe": True,         # original
            "use_wordpiece": True,   # wordpiece_tokenizer
            "use_unigram": True      # unigram_tokenizer
        }

        original_vocab_size, embedding_dim = self.model.embed_tokens.weight.shape
        self.model.extra_embedding_1 = nn.Embedding(original_vocab_size, embedding_dim)
        self.model.extra_embedding_2 = nn.Embedding(original_vocab_size, embedding_dim)
        active_embeddings = sum([
            self.embedding_config.get("use_bpe", False),
            self.embedding_config.get("use_wordpiece", False),
            self.embedding_config.get("use_unigram", False)
        ])
        embedding_dim = self.model.embed_tokens.embedding_dim
        self.embedding_projection = nn.Linear(embedding_dim * active_embeddings, embedding_dim)


In [21]:
model = LlamaWithExtraEmbeddings.from_pretrained("/kaggle/input/model-with-trained-embeddings/saved_model")

Some weights of LlamaWithExtraEmbeddings were not initialized from the model checkpoint at /kaggle/input/model-with-trained-embeddings/saved_model and are newly initialized: ['embedding_projection.bias', 'embedding_projection.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [22]:
model

LlamaWithExtraEmbeddings(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 768)
    (layers): ModuleList(
      (0-11): 12 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=768, out_features=768, bias=False)
          (k_proj): Linear(in_features=768, out_features=768, bias=False)
          (v_proj): Linear(in_features=768, out_features=768, bias=False)
          (o_proj): Linear(in_features=768, out_features=768, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=768, out_features=2048, bias=False)
          (up_proj): Linear(in_features=768, out_features=2048, bias=False)
          (down_proj): Linear(in_features=2048, out_features=768, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((768,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((768,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((768,), eps=1e-05)
    (rotary_emb): Llama

In [23]:
def concat_forward(self, input_ids, attention_mask=None, **kwargs):
    batch_size = input_ids.shape[0]
    combined_embeds = []

    for batch_idx in range(batch_size):
        str_input_ids = " ".join([str(i) for i in input_ids[batch_idx].tolist()])
        input_ids_parts = str_input_ids.split(" 200000 ")

        embed_list = []

        if self.embedding_config.get("use_bpe", False):
            bpe_ids = torch.tensor([list(map(int, input_ids_parts[0].split(" ")))], device=input_ids.device)
            bpe_emb = self.model.embed_tokens(bpe_ids)
            embed_list.append(bpe_emb)

        if self.embedding_config.get("use_wordpiece", False):
            wp_ids = torch.tensor([list(map(int, input_ids_parts[1].split(" ")))], device=input_ids.device)
            wp_emb = self.model.extra_embedding_1(wp_ids)
            embed_list.append(wp_emb)

        if self.embedding_config.get("use_unigram", False):
            uni_ids = torch.tensor([list(map(int, input_ids_parts[2].split(" ")))], device=input_ids.device)
            uni_emb = self.model.extra_embedding_2(uni_ids)
            embed_list.append(uni_emb)

        min_len = min(e.shape[1] for e in embed_list)
        embed_list = [e[:, :min_len, :] for e in embed_list]

        concat_emb = torch.cat(embed_list, dim=-1)

        # Project back
        projected_embeds = self.embedding_projection(concat_emb)
        combined_embeds.append(projected_embeds)

    inputs_embeds = torch.cat(combined_embeds, dim=0)

    if attention_mask is not None:
        attention_mask = attention_mask[:, :inputs_embeds.shape[1]]

    return self.original_forward(
        input_ids=None,
        attention_mask=attention_mask,
        inputs_embeds=inputs_embeds,
        **kwargs
    )


model.original_forward = model.forward
model.forward = types.MethodType(concat_forward, model)

In [24]:
from torch.utils.data import DataLoader

loader = DataLoader(
    unified_tokenized_text,
    batch_size=1,
    collate_fn=multi_tokenizer_data_collator
)

batch = next(iter(loader))


In [25]:
model.eval()
with torch.no_grad():
    output = model(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
        labels=batch["labels"]
    )

print("Loss:", output.loss.item())
print("Logits shape:", output.logits.shape)


Loss: 12.096101760864258
Logits shape: torch.Size([1, 1024, 32000])


## Weighted average

In [75]:
from transformers import LlamaForCausalLM
import torch.nn as nn

class LlamaWithExtraEmbeddings(LlamaForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        
        original_vocab_size, embedding_dim = self.model.embed_tokens.weight.shape
        self.model.extra_embedding_1 = nn.Embedding(original_vocab_size, embedding_dim)
        self.model.extra_embedding_2 = nn.Embedding(original_vocab_size, embedding_dim)
        self.embedding_config = {
            "use_bpe": True,
            "use_wordpiece": True,
            "use_unigram": True,
            "weights": {
                "bpe": 0.5,
                "wordpiece": 0.,
                "unigram": 0.5
            }
        }


In [76]:
model = LlamaWithExtraEmbeddings.from_pretrained("/kaggle/input/model-with-trained-embeddings/saved_model")

In [77]:
def weighted_avg_forward(self, input_ids, attention_mask=None, **kwargs):
    batch_size = input_ids.shape[0]
    combined_embeds = []

    weights = self.embedding_config.get("weights", {})
    total_weight = sum(weights.get(key, 0.0) for key in ["bpe", "wordpiece", "unigram"])

    for batch_idx in range(batch_size):
        str_input_ids = " ".join([str(i) for i in input_ids[batch_idx].tolist()])
        input_ids_parts = str_input_ids.split(" 200000 ")

        embed_list = []
        weight_list = []

        if self.embedding_config.get("use_bpe", False):
            bpe_ids = torch.tensor([list(map(int, input_ids_parts[0].split(" ")))], device=input_ids.device)
            bpe_emb = self.model.embed_tokens(bpe_ids)
            embed_list.append(bpe_emb)
            weight_list.append(weights.get("bpe", 0.0))

        if self.embedding_config.get("use_wordpiece", False):
            wp_ids = torch.tensor([list(map(int, input_ids_parts[1].split(" ")))], device=input_ids.device)
            wp_emb = self.model.extra_embedding_1(wp_ids)
            embed_list.append(wp_emb)
            weight_list.append(weights.get("wordpiece", 0.0))

        if self.embedding_config.get("use_unigram", False):
            uni_ids = torch.tensor([list(map(int, input_ids_parts[2].split(" ")))], device=input_ids.device)
            uni_emb = self.model.extra_embedding_2(uni_ids)
            embed_list.append(uni_emb)
            weight_list.append(weights.get("unigram", 0.0))


        min_len = min(e.shape[1] for e in embed_list)
        embed_list = [e[:, :min_len, :] for e in embed_list]


        norm_weights = [w / total_weight for w in weight_list]
        # norm_weights = weight_list

        weighted_emb = sum(w * emb for w, emb in zip(norm_weights, embed_list))
        combined_embeds.append(weighted_emb)


    inputs_embeds = torch.cat(combined_embeds, dim=0)

    if attention_mask is not None:
        attention_mask = attention_mask[:, :inputs_embeds.shape[1]]

    return self.original_forward(
        input_ids=None,
        attention_mask=attention_mask,
        inputs_embeds=inputs_embeds,
        **kwargs
    )


model.original_forward = model.forward
model.forward = types.MethodType(weighted_avg_forward, model)

In [78]:
from torch.utils.data import DataLoader

loader = DataLoader(
    unified_tokenized_text,
    batch_size=1,
    collate_fn=multi_tokenizer_data_collator
)

batch = next(iter(loader))

In [79]:
model.eval()
with torch.no_grad():
    output = model(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
        labels=batch["labels"]
    )

print("Loss:", output.loss.item())
print("Logits shape:", output.logits.shape)

Loss: 7.512155532836914
Logits shape: torch.Size([1, 1024, 32000])


| BPE Weight | Wordpiece Weight | Unigram Weight | Loss                    |
|------------|------------------|----------------|-------------------------|
| 0.9        | 0.1              | 0.1            | 7.031398773193359       |
| 0.5        | 0.4              | 0.1            | 7.331758499145508       |
| 1.0        | 0.0              | 0.0            | 7.0994086265563965      |
| 0.5        | 0.5              | 0.0            | 8.43027400970459        |
| 0.5        | 0.0              | 0.5            | 7.512155532836914       |
| 1/3        | 1/3              | 1/3            | 8.981038093566895       |
