In [1]:
from NitMultilingualEncoders import NitQwenMathInstruct, NitMT5encoder

In [2]:
from AlignmentModels import Conv1dAutoencoder

In [3]:
import numpy as np
import pandas as pd

In [4]:
import torch
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split

In [5]:
from tqdm import tqdm

In [6]:
from datasets import load_dataset

In [7]:
import re
def extract_num_output(text):
    """Pull the final answer out of a solution string.

    Looks for the marker ``The answer is:`` followed by exactly one
    whitespace character and returns everything after it up to the end of
    the string. Because ``.`` does not match newlines and ``$`` is used
    without ``re.MULTILINE``, the marker must sit on the last line of
    ``text`` (a single trailing newline is tolerated).

    Args:
        text: The full response text to scan.

    Returns:
        The text following the marker (possibly an empty string), or
        ``None`` when the marker is not found on the last line.
    """
    found = re.search(r'(?<=The answer is:\s).*$', text)
    return found.group(0) if found else None

In [8]:
# Local cache directories: `data_dir` is passed as `cache_dir` to the dataset
# loads below; `model_dir` is passed as `cache_dir` to the encoder wrappers.
data_dir = "data_cache"
model_dir = "model_cache"
# Load the math QA corpus and the English Wikipedia sentence corpus via
# `datasets.load_dataset`, caching both under `data_dir`.
math_ds = load_dataset("meta-math/MetaMathQA", cache_dir=data_dir)
sen_ds = load_dataset("sentence-transformers/wikipedia-en-sentences",cache_dir=data_dir)

In [9]:
# Convert the 'train' split of each dataset to a pandas DataFrame.
math_df = math_ds['train'].to_pandas()
sen_df = sen_ds['train'].to_pandas()

In [10]:
# Preview the raw math frame (type, query, original_question, response).
math_df.head()

Unnamed: 0,type,query,original_question,response
0,MATH_AnsAug,Gracie and Joe are choosing numbers on the com...,Gracie and Joe are choosing numbers on the com...,"The distance between two points $(x_1,y_1)$ an..."
1,GSM_Rephrased,What is the total cost of purchasing equipment...,The treasurer of a football team must buy equi...,"Each player requires a $25 jersey, a $15.20 pa..."
2,GSM_SV,Diego baked 12 cakes for his sister's birthday...,Diego baked 12 cakes for his sister's birthday...,"To solve this problem, we need to determine th..."
3,MATH_AnsAug,Convert $10101_3$ to a base 10 integer.,Convert $10101_3$ to a base 10 integer.,$10101_3 = 1 \cdot 3^4 + 0 \cdot 3^3 + 1 \cdot...
4,GSM_FOBAR,"Sue works in a factory and every 30 minutes, a...","Sue works in a factory and every 30 minutes, a...","We know that every 30 minutes, a machine produ..."


In [11]:
# Preview the sentence frame (single `sentence` column).
sen_df.head()

Unnamed: 0,sentence
0,"The film stars M. G. Ramachandran, Latha, Anja..."
1,Naarda plenirena is a species of moth in the f...
2,Sponsored by the American Federation of Labor ...
3,Since that election the Belfast Corporation Ac...
4,It was also included on their Best of Volume 1.


In [12]:
# Add a `Numerical_output` column holding whatever `extract_num_output`
# parses from each `response` (None when the answer marker is not found).
math_df['Numerical_output']= math_df['response'].apply(extract_num_output)

In [13]:
# Preview again to check the newly added `Numerical_output` column.
math_df.head()

Unnamed: 0,type,query,original_question,response,Numerical_output
0,MATH_AnsAug,Gracie and Joe are choosing numbers on the com...,Gracie and Joe are choosing numbers on the com...,"The distance between two points $(x_1,y_1)$ an...",\sqrt{5}
1,GSM_Rephrased,What is the total cost of purchasing equipment...,The treasurer of a football team must buy equi...,"Each player requires a $25 jersey, a $15.20 pa...",752
2,GSM_SV,Diego baked 12 cakes for his sister's birthday...,Diego baked 12 cakes for his sister's birthday...,"To solve this problem, we need to determine th...",1
3,MATH_AnsAug,Convert $10101_3$ to a base 10 integer.,Convert $10101_3$ to a base 10 integer.,$10101_3 = 1 \cdot 3^4 + 0 \cdot 3^3 + 1 \cdot...,91
4,GSM_FOBAR,"Sue works in a factory and every 30 minutes, a...","Sue works in a factory and every 30 minutes, a...","We know that every 30 minutes, a machine produ...",1


In [14]:
# Count missing values per column; a zero for `Numerical_output` means
# `extract_num_output` found an answer in every response.
math_df.isna().sum()

type                 0
query                0
original_question    0
response             0
Numerical_output     0
dtype: int64

In [15]:
# Count missing values in the sentence frame.
sen_df.isna().sum()

sentence    0
dtype: int64

# Data Loaders

In [16]:
# Batch size shared by every DataLoader built in this section.
batch_size = 1024

In [17]:
# Select the text columns used as model input: the math `query` strings and
# the Wikipedia `sentence` strings.
math_train_query = math_df['query']
sen_train = sen_df['sentence']

In [18]:
# Convert both Series to NumPy arrays before splitting.
math_array = math_train_query.to_numpy()
sen_array = sen_train.to_numpy()

In [19]:
# 80/20 train/test split of each corpus; the fixed `random_state` keeps the
# split reproducible across runs.
math_X_train, math_X_test = train_test_split(math_array, test_size=0.2, random_state=42)
sen_X_train, sen_X_test = train_test_split(sen_array, test_size=0.2, random_state=42)

In [20]:
# Batch the math queries; only the training loader shuffles.
math_train_loader = DataLoader(math_X_train, batch_size=batch_size, shuffle=True)
math_test_loader = DataLoader(math_X_test, batch_size=batch_size, shuffle=False)

In [21]:
# Batch the Wikipedia sentences; only the training loader shuffles.
sen_train_loader = DataLoader(sen_X_train, batch_size=batch_size, shuffle=True)
sen_test_loader = DataLoader(sen_X_test, batch_size=batch_size, shuffle=False)

# LLM and Encoder Initialization

In [22]:
# Tokenization settings passed to both encoder wrappers below.
# NOTE(review): with padding="max_length" every sequence presumably gets
# padded to `max_tokens`; the 100-id outputs further down are consistent with
# that — confirm in NitMultilingualEncoders.
max_tokens = 100
padding = "max_length"

In [23]:
# Instantiate the project's Qwen wrapper, caching under `model_dir` and
# sharing the tokenization settings defined above.
qwen = NitQwenMathInstruct(cache_dir=model_dir, max_tokens=max_tokens, padding=padding)

In [24]:
# Instantiate the project's mT5 encoder wrapper with the same cache directory
# and tokenization settings as the Qwen wrapper.
mt5 = NitMT5encoder(cache_dir=model_dir, max_tokens=max_tokens, padding=padding)



In [25]:
# Ask each wrapper for its embedding shape; both values are used to construct
# the alignment model.
mt5_embedding_shape = mt5.getEmbedding_shape()
qwen_embedding_shape = qwen.getEmbedding_shape()

# Alignment Model

In [26]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [27]:
# Build the alignment model from the two embedding shapes and move it to
# `device`.
# NOTE(review): argument order is (mT5 shape, Qwen shape), presumably
# input -> target, since `train` feeds mT5 embeddings in and compares the
# output with Qwen embeddings — confirm against Conv1dAutoencoder.
align_model = Conv1dAutoencoder(mt5_embedding_shape, qwen_embedding_shape).to(device)

# Testing the Nit model wrappers: embeddings and token IDs

In [34]:
# A single short prompt used by the sanity checks in this section.
test_texts = ["What?"]

In [47]:
# Inspect the Qwen input embeddings for the first (only) test text.
# NOTE(review): the repeated identical leading rows in the output presumably
# correspond to padding positions — confirm.
qwen.get_embeddings_from_text(test_texts).input_embeds[0]

tensor([[ 0.0408,  0.0123,  0.0154,  ...,  0.0070,  0.0210, -0.0142],
        [ 0.0408,  0.0123,  0.0154,  ...,  0.0070,  0.0210, -0.0142],
        [ 0.0408,  0.0123,  0.0154,  ...,  0.0070,  0.0210, -0.0142],
        ...,
        [ 0.0408,  0.0123,  0.0154,  ...,  0.0070,  0.0210, -0.0142],
        [ 0.0248,  0.0420,  0.0310,  ...,  0.0300,  0.0059, -0.0067],
        [ 0.0376,  0.0029, -0.0327,  ..., -0.0299, -0.0018, -0.0033]],
       device='cuda:0', grad_fn=<SelectBackward0>)

In [52]:
# Inspect the Qwen token ids for the same text; the output shows the real
# tokens at the end, preceded by a repeated filler id.
qwen.get_inputIds(test_texts).input_ids[0]

tensor([151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643,
        151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643,
        151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643,
        151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643,
        151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643,
        151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643,
        151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643,
        151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643,
        151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643,
        151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643,
        151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643,   3838,
            30], device='cuda:0')

In [53]:
# Look up the mT5 tokenizer's end-of-sequence token id.
mt5.tokenizer.eos_token_id

1

In [55]:
# Look up the Qwen tokenizer's pad token id, to compare against the filler id
# seen in the Qwen token-id output.
qwen.tokenizer.pad_token_id

151643

In [51]:
# Inspect the mT5 token ids for the same text; the output shows leading zeros
# followed by the real tokens and a final id of 1.
mt5.get_inputIds(test_texts).input_ids[0]

tensor([   0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0, 5126,  291,    1], device='cuda:0')

# Training

In [30]:
# define train loop
def train(model, train_loader, test_loader, optimizer, criterion, epochs=10):
    """Fit `model` to map mT5 input embeddings onto Qwen input embeddings.

    For every batch of raw inputs, both encoders produce `input_embeds`
    (without gradient tracking); `model` receives the mT5 embeddings and
    `criterion` compares its output with the Qwen embeddings. After each
    epoch the model is evaluated on `test_loader` and the per-sample average
    train and validation losses are printed.

    Relies on the notebook-level globals `mt5` and `qwen` (the two encoder
    wrappers) and on `tqdm` for the progress bar.

    Args:
        model: Module being trained; called as `model(mt5_embedding)`.
        train_loader: Iterable of training batches passed to both encoders.
        test_loader: Iterable of validation batches passed to both encoders.
        optimizer: Optimizer over `model`'s parameters.
        criterion: Loss called as `criterion(output, qwen_embedding)`.
            Assumed to return a batch-mean scalar, since it is re-weighted by
            the batch size below — TODO confirm.
        epochs: Number of passes over `train_loader`.

    Returns:
        None. Progress is reported via `print`.
    """
    for epoch in range(epochs):
        # ---- training pass ----
        model.train()
        train_loss = 0.0
        train_samples = 0

        for batch in tqdm(train_loader, desc = f'epoch_{epoch+1}/{epochs}'):

            # The encoders only supply inputs/targets; no gradients needed.
            # NOTE(review): other cells call `get_embeddings_from_text`;
            # confirm `get_embeddings` accepts a batch of raw strings.
            with torch.no_grad():
                mt5_embedding = mt5.get_embeddings(batch).input_embeds
                qwen_embedding = qwen.get_embeddings(batch).input_embeds

            optimizer.zero_grad()
            output = model(mt5_embedding)
            loss = criterion(output, qwen_embedding)

            loss.backward()
            optimizer.step()

            # Weight the batch loss by its size so a smaller final batch
            # contributes proportionally.
            n_batch = output.size(0)
            train_loss += loss.item() * n_batch
            train_samples += n_batch

        # Normalize by samples seen, not by number of batches
        # (`len(loader)`), to match the per-sample accumulation above.
        # `max(..., 1)` avoids a ZeroDivisionError on an empty loader.
        train_loss /= max(train_samples, 1)

        # ---- validation pass ----
        model.eval()
        val_loss = 0.0
        val_samples = 0
        with torch.no_grad():
            for batch in test_loader:
                mt5_embedding = mt5.get_embeddings(batch).input_embeds
                qwen_embedding = qwen.get_embeddings(batch).input_embeds

                output = model(mt5_embedding)
                loss = criterion(output, qwen_embedding)

                n_batch = output.size(0)
                val_loss += loss.item() * n_batch
                val_samples += n_batch

        val_loss /= max(val_samples, 1)

        # `{epochs}` is now interpolated; previously the literal text
        # "/epochs" was printed.
        print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.7f}, Val Loss: {val_loss:.7f}")