In [None]:
from .NitMultilingualEncoders import NitQwenMathInstruct, NitMT5encoder

In [None]:
from .AlignmentModels import Conv1dAutoencoder

In [None]:
import numpy as np
import pandas as pd

In [None]:
import torch
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split

In [None]:
from tqdm import tqdm

In [None]:
from datasets import load_dataset

In [None]:
import re
def extract_num_output(text):
    """Return whatever follows "The answer is:" in a response string.

    The marker must be followed by exactly one whitespace character, and the
    answer must run to the end of the string (a single trailing newline is
    tolerated by ``$``); ``.`` does not cross newlines, so an answer followed
    by further lines is not matched.

    Args:
        text: The response text. Non-string values (e.g. ``None`` or the
            float NaN pandas uses for missing cells) are accepted.

    Returns:
        The matched answer as a string, or ``None`` when ``text`` is not a
        string or the marker is not found.
    """
    # Missing cells would otherwise make re.search raise TypeError when this
    # function is applied over a DataFrame column.
    if not isinstance(text, str):
        return None
    match = re.search(r'(?<=The answer is:\s).*$', text)
    return match.group(0) if match else None

In [None]:
# Local cache directories: data_dir is passed to load_dataset below, model_dir
# is passed to the encoder wrappers further down the notebook.
data_dir = "data_cache"
model_dir = "model_cache"
# Two corpora: MetaMathQA (math queries/responses) and a set of English
# Wikipedia sentences. Both are cached under data_dir.
math_ds = load_dataset("meta-math/MetaMathQA", cache_dir=data_dir)
sen_ds = load_dataset("sentence-transformers/wikipedia-en-sentences",cache_dir=data_dir)

In [None]:
# Materialise the 'train' split of each dataset as a pandas DataFrame.
# Dataset.to_pandas() is the conversion the datasets library provides for this;
# pd.DataFrame(dataset) instead has to consume the dataset as a generic Python
# iterable of rows.
math_df = math_ds['train'].to_pandas()
sen_df = sen_ds['train'].to_pandas()

In [None]:
# Preview the first rows of the math DataFrame.
math_df.head()

In [None]:
# Preview the first rows of the sentence DataFrame.
sen_df.head()

In [None]:
# Add a column holding the text after "The answer is:" in each response
# (None where extract_num_output finds no match). This mutates math_df in place.
math_df['Numerical_output']= math_df['response'].apply(extract_num_output)

In [None]:
# Preview again to inspect the newly added Numerical_output column.
math_df.head()

In [None]:
# Missing values per column; for Numerical_output this counts the responses
# where no answer could be extracted.
math_df.isna().sum()

In [None]:
# Missing values per column in the sentence DataFrame.
sen_df.isna().sum()

# Data Loaders

In [None]:
# Number of text samples per batch, shared by every DataLoader below.
batch_size = 1024

In [None]:
# Keep only the text columns used for the loaders: the 'query' column of the
# math data and the 'sentence' column of the sentence data.
math_train_query = math_df['query']
sen_train = sen_df['sentence']

In [None]:
# Convert the selected Series to NumPy arrays before splitting.
math_array = math_train_query.to_numpy()
sen_array = sen_train.to_numpy()

In [None]:
# 80/20 train/test split of each corpus; the fixed random_state makes the
# split reproducible across runs.
math_X_train, math_X_test = train_test_split(math_array, test_size=0.2, random_state=42)
sen_X_train, sen_X_test = train_test_split(sen_array, test_size=0.2, random_state=42)

In [None]:
# Loaders over the math queries: the training loader shuffles, the test loader
# keeps a fixed order.
# NOTE(review): batches are presumably collated into lists of raw strings for
# the encoders to tokenize -- confirm against get_embeddings.
math_train_loader = DataLoader(math_X_train, batch_size=batch_size, shuffle=True)
math_test_loader = DataLoader(math_X_test, batch_size=batch_size, shuffle=False)

In [None]:
# Loaders over the Wikipedia sentences, configured the same way as the math
# loaders: shuffled for training, fixed order for testing.
sen_train_loader = DataLoader(sen_X_train, batch_size=batch_size, shuffle=True)
sen_test_loader = DataLoader(sen_X_test, batch_size=batch_size, shuffle=False)

# LLM and Encoder init

In [None]:
# Settings passed to both encoder wrappers below.
# NOTE(review): "max_length" presumably pads every sequence to max_tokens so
# both encoders emit equal-length sequences -- confirm in NitMultilingualEncoders.
max_tokens = 100
padding = "max_length"

In [None]:
# Qwen math-instruct wrapper; the training loop uses its embeddings as the
# target the alignment model is trained to reproduce.
qwen = NitQwenMathInstruct(cache_dir=model_dir, max_tokens=max_tokens, padding=padding)

In [None]:
# mT5 encoder wrapper; the training loop feeds its embeddings to the alignment
# model as input.
mt5 = NitMT5encoder(cache_dir=model_dir, max_tokens=max_tokens, padding=padding)

In [None]:
# Embedding shapes reported by each wrapper, used to size the alignment model.
# NOTE(review): the exact layout of the returned shape is not visible here --
# check getEmbedding_shape in NitMultilingualEncoders.
mt5_embedding_shape = mt5.getEmbedding_shape()
qwen_embedding_shape = qwen.getEmbedding_shape()

# Alignment Model

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Alignment network built from the two embedding shapes and moved to `device`.
# NOTE(review): the argument order suggests mT5 is the input side and Qwen the
# output side, matching how the training loop uses them -- confirm against
# Conv1dAutoencoder.
align_model = Conv1dAutoencoder(mt5_embedding_shape, qwen_embedding_shape).to(device)

# Training

In [None]:
# define train loop
def train(model, train_loader, test_loader, optimizer, criterion, epochs=10):
    """Train ``model`` to map mT5 embeddings onto Qwen embeddings.

    For every batch, the module-level ``mt5`` and ``qwen`` wrappers embed the
    same batch; ``model`` receives the mT5 embeddings and ``criterion``
    compares its output with the Qwen embeddings. After each epoch the model
    is evaluated on ``test_loader`` and one summary line is printed.

    Reported losses are per-sample averages: each batch loss is weighted by
    the batch size and the total is divided by the number of samples seen, so
    a smaller final batch is weighted correctly. This assumes ``criterion``
    returns a batch mean (e.g. the default ``reduction='mean'``).

    Args:
        model: Module called as ``model(mt5_embedding)``.
        train_loader: Iterable of batches used for optimisation.
        test_loader: Iterable of batches used for validation.
        optimizer: Optimiser bound to ``model``'s parameters.
        criterion: Callable ``criterion(output, target)`` returning a scalar
            loss tensor.
        epochs: Number of passes over ``train_loader``.

    Returns:
        None. Progress is reported through tqdm and ``print``.
    """
    # NOTE(review): embeddings are used on whatever device the wrappers return
    # them on; confirm that matches the device `model` lives on.
    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        train_seen = 0

        for batch in tqdm(train_loader, desc=f'epoch_{epoch+1}/{epochs}'):
            # Embeddings are computed without autograd tracking, so gradients
            # only reach the alignment model.
            with torch.no_grad():
                mt5_embedding = mt5.get_embeddings(batch).input_embeds
                qwen_embedding = qwen.get_embeddings(batch).input_embeds

            optimizer.zero_grad()
            output = model(mt5_embedding)
            loss = criterion(output, qwen_embedding)

            loss.backward()
            optimizer.step()

            n_samples = output.size(0)
            train_loss += loss.item() * n_samples
            train_seen += n_samples

        # Divide by samples, not batches: the running total is sample-weighted.
        # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
        train_loss /= max(train_seen, 1)

        model.eval()
        val_loss = 0.0
        val_seen = 0
        with torch.no_grad():
            for batch in test_loader:
                mt5_embedding = mt5.get_embeddings(batch).input_embeds
                qwen_embedding = qwen.get_embeddings(batch).input_embeds

                output = model(mt5_embedding)
                loss = criterion(output, qwen_embedding)

                n_samples = output.size(0)
                val_loss += loss.item() * n_samples
                val_seen += n_samples

        val_loss /= max(val_seen, 1)
        print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.7f}, Val Loss: {val_loss:.7f}")