In [None]:
from m2_cw import load_qwen
import torch
import torch.nn as nn
import copy
from bidict import bidict
from transformers.tokenization_utils_base import BatchEncoding

# Load Qwen with its full vocabulary (reduce_vocabulary=False) and place it on the CPU.
device = "cpu"
model, tokenizer = load_qwen(reduce_vocabulary=False)
model.to(device)

# The only characters the reduced vocabulary will contain:
# the ten digits plus ',', '.' and ';'.
valid_words = [character for character in "0123456789,.;"]
print(valid_words)

Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.


['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ',', '.', ';']


In [None]:
# Work on a copy so the original `model` keeps its full-vocabulary embedding
# and this cell can be re-run without changing its own inputs.
new_model = copy.deepcopy(model)

# Work out Qwen's token id for every word of the reduced vocabulary.
qwen_tokens = []
for word in valid_words:
    word_ids = tokenizer(word, return_tensors="pt", add_special_tokens=False).input_ids[0]
    # The one-to-one token map below only makes sense if each word is a single Qwen token.
    if word_ids.numel() != 1:
        raise ValueError(
            f"{word!r} is tokenized into {word_ids.numel()} tokens; expected exactly one"
        )
    qwen_tokens.append(word_ids.item())

# My own token ids are simply 0..N-1, in the order of `valid_words`.
new_tokens = list(range(len(qwen_tokens)))

# Two-way map: Qwen token id <-> reduced ("forecast") token id.
token_map = bidict({qwen: forecast for qwen, forecast in zip(qwen_tokens, new_tokens)})

# Fail with an explicit message instead of the opaque
# "IndexError: index out of range in self" raised from inside the embedding lookup.
source_embedding = model.model.embed_tokens
out_of_range = [
    qwen for qwen in qwen_tokens if not 0 <= qwen < source_embedding.num_embeddings
]
if out_of_range:
    raise IndexError(
        f"Qwen token ids {out_of_range} are not valid rows of model.model.embed_tokens "
        f"(num_embeddings={source_embedding.num_embeddings}); "
        "check that `model` still carries the full-vocabulary embedding"
    )

# Look up all embedding rows in one call. no_grad avoids building an autograd graph
# that nn.Parameter would discard anyway; the result is a fresh tensor, so the new
# parameter does not share storage with the original embedding matrix.
with torch.no_grad():
    embedding_vectors = source_embedding(
        torch.tensor(qwen_tokens, dtype=torch.long, device=source_embedding.weight.device)
    )
embedding_vectors = nn.Parameter(embedding_vectors)

# Make a new embedding that only has tokens from our allowed words list;
# row i holds Qwen's embedding vector for valid_words[i].
new_embedding = nn.Embedding(
    num_embeddings=len(new_tokens),
    embedding_dim=source_embedding.embedding_dim,
)
new_embedding.weight = embedding_vectors
new_model.model.embed_tokens = new_embedding

IndexError: index out of range in self

In [None]:
def convert_tokens_to(desired_type: str,
                      tokens: BatchEncoding, 
                      map: bidict=token_map):
    """Translate the ``input_ids`` of ``tokens`` between Qwen ids and forecast ids.

    Args:
        desired_type: ``"forecast"`` to convert Qwen ids to forecast ids, or
            ``"qwen"`` to convert forecast ids back to Qwen ids.
        tokens: Tokenizer output whose ``input_ids`` are converted. It is not
            modified; a deep copy is returned.
        map: Two-way mapping with Qwen ids as keys and forecast ids as values.

    Returns:
        A deep copy of ``tokens`` whose ``"input_ids"`` entry holds the converted
        ids. Ids that have no entry in ``map`` (including negative ids and ids
        beyond the largest mapped id) become ``-1``.

    Raises:
        ValueError: If ``desired_type`` is neither ``"forecast"`` nor ``"qwen"``.
    """
    # Collect (source id, target id) pairs for the requested direction.
    # Uses the `map` argument, so a caller-supplied mapping is honoured.
    if desired_type == "forecast":
        print("converting to forecast")
        pairs = [(qwen, forecast) for qwen, forecast in map.items()]
    elif desired_type == "qwen":
        print("converting to qwen")
        pairs = [(forecast, qwen) for qwen, forecast in map.items()]
    else:
        raise ValueError(
            f"desired_type must be 'forecast' or 'qwen', got {desired_type!r}"
        )

    max_key = max(source for source, _ in pairs)
    print(max_key)

    input_ids = torch.as_tensor(tokens.input_ids)

    # Dense lookup table on the same device as the ids; -1 marks unmapped source ids.
    lookup = torch.full((max_key + 1,), -1, dtype=torch.long, device=input_ids.device)
    sources = torch.tensor([source for source, _ in pairs], dtype=torch.long, device=input_ids.device)
    targets = torch.tensor([target for _, target in pairs], dtype=torch.long, device=input_ids.device)
    lookup[sources] = targets

    # Ids outside the table must not index it: negative ids would wrap around to
    # the end of the table and large ids would raise. Clamp them for the gather
    # and overwrite the result with -1 afterwards.
    in_table = (input_ids >= 0) & (input_ids <= max_key)
    gathered = lookup[input_ids.clamp(min=0, max=max_key)]
    new_input_ids = torch.where(in_table, gathered, torch.full_like(gathered, -1))

    converted = copy.deepcopy(tokens)
    converted["input_ids"] = new_input_ids
    return converted

# Two sample strings built only from the allowed characters.
string = ["1.23,8.12;", "9.87,7.43;"]

# Round trip: Qwen ids -> forecast ids -> Qwen ids.
# Note that this rebinds `qwen_tokens` to the tokenizer output.
qwen_tokens = tokenizer(string, add_special_tokens=False, return_tensors="pt")
forecast_tokens = convert_tokens_to("forecast", qwen_tokens)
qwen_tokens_again = convert_tokens_to(desired_type="qwen", tokens=forecast_tokens)

# Show the three encodings in order: original, converted, converted back.
for encoding in (qwen_tokens, forecast_tokens, qwen_tokens_again):
    print(encoding)

converting to forecast
26
converting to qwen
12
{'input_ids': tensor([[16, 13, 17, 18, 11, 23, 13, 16, 17, 26],
        [24, 13, 23, 22, 11, 22, 13, 19, 18, 26]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}
{'input_ids': tensor([[ 1, 11,  2,  3, 10,  8, 11,  1,  2, 12],
        [ 9, 11,  8,  7, 10,  7, 11,  4,  3, 12]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}
{'input_ids': tensor([[16, 13, 17, 18, 11, 23, 13, 16, 17, 26],
        [24, 13, 23, 22, 11, 22, 13, 19, 18, 26]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}


In [None]:
lm_head = model.lm_head

# Order the Qwen ids by their forecast id explicitly, so the row order of the new
# head does not depend on the iteration order of `token_map`.
# With forecast ids 0..N-1 this makes output row i score forecast token i.
pairs_by_forecast = sorted(token_map.items(), key=lambda pair: pair[1])
head_rows = torch.tensor(
    [qwen for qwen, _ in pairs_by_forecast],
    dtype=torch.long,
    device=lm_head.weight.device,
)

# Select all rows in one indexing operation. detach() drops any autograd history and
# clone() guarantees the new parameters own their storage, on the head's own
# device and in its own dtype.
new_weight = nn.Parameter(lm_head.weight.detach()[head_rows].clone())

# The head may have been built without a bias; only slice it when it exists.
if lm_head.bias is None:
    new_bias = None
else:
    new_bias = nn.Parameter(lm_head.bias.detach()[head_rows].clone())

new_lm_head = nn.Linear(
    in_features=lm_head.in_features,
    out_features=len(pairs_by_forecast),
    bias=new_bias is not None,
)
new_lm_head.weight = new_weight
if new_bias is not None:
    new_lm_head.bias = new_bias

print(new_lm_head)

Linear(in_features=896, out_features=13, bias=True)
