In [23]:
from NitMultilingualEncoders import NitQwenMathInstruct, NitMT5encoder, NitRobertaencoder

In [24]:
from AlignmentModels import Conv1dAutoencoder, CNN1DRBencoder

In [25]:
import torch

In [26]:
def load_model(model, filename='model.pth'):
    """Load a saved state dict from ``filename`` into ``model`` and put it in eval mode.

    Args:
        model: Object exposing ``load_state_dict`` and ``eval`` (e.g. a
            ``torch.nn.Module``). It is updated in place.
        filename: Path to a checkpoint readable by ``torch.load`` that holds a
            plain state dict (parameter name -> tensor).

    Returns:
        None. The function prints a confirmation line once loading succeeded.

    Raises:
        Any error raised by ``torch.load`` (e.g. missing file) or by
        ``model.load_state_dict`` (e.g. missing / unexpected keys).
    """
    # Deserialize onto the CPU so that a checkpoint written from a GPU still
    # loads on a CPU-only host. load_state_dict copies the values into the
    # model's existing parameters, so the model's own device is unaffected.
    state_dict = torch.load(filename, map_location="cpu", weights_only=True)

    # Checkpoints saved from a wrapped model (e.g. nn.DataParallel) carry a
    # leading 'module.' on every key. Strip it only when it is a prefix: a
    # blanket str.replace would also mangle keys such as 'submodule.weight'.
    prefix = 'module.'
    new_state_dict = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }

    model.load_state_dict(new_state_dict)
    model.eval()  # Set the model to evaluation mode
    print(f"Model loaded from {filename}")

In [27]:
# Directory handed to both encoder wrappers as their `cache_dir` argument.
model_dir = "model_cache"

In [28]:
# Settings forwarded unchanged to both encoder wrappers as their
# `max_tokens` and `padding` keyword arguments.
max_tokens, padding = 100, "max_length"

In [29]:
# Prompt template registered on the Qwen wrapper via setTemplate; `{text}` is
# the only placeholder.
# NOTE(review): no role line (e.g. "user") follows <|im_start|> -- confirm
# this matches the prompt format the alignment model was trained against.
qwen_template = (
    "<|im_start|>"
    "{text}"
    "<|im_end|>"
)

In [30]:
# Build the Qwen wrapper with the shared cache directory and token settings,
# then register the prompt template on it.
qwen = NitQwenMathInstruct(
    cache_dir=model_dir,
    max_tokens=max_tokens,
    padding=padding,
)
qwen.setTemplate(qwen_template)

In [31]:
# Build the RoBERTa wrapper with the same cache directory and token settings
# as the Qwen wrapper.
# NOTE(review): the recorded warning below reports freshly initialised pooler
# weights; presumably harmless if only token embeddings are used -- confirm.
rb = NitRobertaencoder(
    cache_dir=model_dir,
    max_tokens=max_tokens,
    padding=padding,
)

Some weights of RobertaModel were not initialized from the model checkpoint at FacebookAI/roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [32]:
# Ask each wrapper for its embedding shape (RoBERTa first, then Qwen); both
# values are later passed to the alignment model's constructor.
rb_embedding_shape, qwen_embedding_shape = (
    rb.getEmbedding_shape(),
    qwen.getEmbedding_shape(),
)

In [33]:
# Display both shapes as the cell output (RoBERTa first, Qwen second).
rb_embedding_shape, qwen_embedding_shape

((100, 768), (100, 1536))

In [34]:
# Use the GPU when PyTorch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [35]:
# Construct the alignment network from the two embedding shapes (RoBERTa
# shape first, Qwen shape second), then move it onto the selected device.
alignment_model = CNN1DRBencoder(rb_embedding_shape, qwen_embedding_shape)
alignment_model = alignment_model.to(device)

In [36]:
# Restore the saved alignment weights; load_model also switches the model to
# evaluation mode and prints a confirmation line.
load_model(alignment_model, filename="qwen_rb.pth")

Model loaded from qwen_rb.pth


In [142]:
# Single word problem used as the prompt for both generation runs.
test_texts = [
    "There were nine computers in the server room. "
    "Five more computers were installed each day, from monday to thursday. "
    "How many computers are now in the server room?"
]

In [143]:
# Baseline: let the Qwen wrapper turn the prompt into its own embeddings.
# These are later passed to qwen.get_generation as the reference run.
original_embeddings = qwen.get_embeddings_from_text(test_texts)

In [144]:
# Same prompt through the RoBERTa wrapper. The returned object exposes an
# `input_embeds` attribute, which later cells read and then overwrite.
rb_embeddings = rb.get_embeddings_from_text(test_texts)

In [145]:
# Run the RoBERTa input embeddings through the alignment model (presumably
# mapping them into Qwen's embedding space -- confirm against the model code).
# This is inference only, so disable autograd: without no_grad the forward
# pass records a graph and the result keeps it alive for no benefit.
with torch.no_grad():
    altered_embeddings = alignment_model(rb_embeddings.input_embeds)

In [146]:
# Reference run: generate from Qwen's own embeddings of the prompt. The
# returned value is shown as the cell output.
qwen.get_generation(original_embeddings)

[' on Monday, Tuesday, Wednesday, and Thursday, 5 more computers are installed each day. Therefore, the total number of computers installed from Monday to Thursday is:\n\\[ 5 \\times 4 = 20 \\]\n\nAdding the initial number of computers to the total number of computers installed gives:\n\\[ 9 + 20 = 29 \\]\n\nSo, the total number of computers now in the server room is:\n\\[ \\boxed{29} \\]']

In [147]:
# Replace the RoBERTa embeddings with the alignment model's output so the
# same container can be handed to qwen.get_generation.
# NOTE(review): this mutates rb_embeddings in place. Re-running the alignment
# cell after this one would feed the already-aligned tensor back into
# alignment_model; re-create rb_embeddings first (or run cells in order).
rb_embeddings.input_embeds = altered_embeddings

In [148]:
# Generate from the aligned embeddings (rb_embeddings.input_embeds was
# overwritten with the alignment model's output in the previous cell).
# NOTE(review): the recorded output discusses birds in a game box rather than
# computers in a server room, so for this prompt the aligned embeddings do not
# reproduce the reference generation.
qwen.get_generation(rb_embeddings)

[" added\nTo determine the total number of birds in the game box after a certain number of days, we need to follow these steps:\n\n1. Identify the initial number of birds in the game box.\n2. Determine the number of birds added each day.\n3. Calculate the total number of birds added over the specified period.\n4. Add the initial number of birds to the total number of birds added.\n\nLet's go through each step in detail:\n\n1. **Initial number of birds:**\n   The game box initially contains 6 birds.\n\n2. **Number of birds added each day:**\n   Each day, 5 more birds are added to the game box.\n\n3. **Period of time:**\n   The addition of birds starts from Friday and continues through to the next Wednesday. This period includes:\n   - Friday\n   - Saturday\n   - Sunday\n   - Monday\n   - Tuesday\n   - Wednesday\n\n   Therefore, the total number of days is 6 days.\n\n4. **Total number of birds added:**\n   Since 5 birds are added each day for 6 days, we calculate the total number of bird