In [2]:
from NitMultilingualEncoders import NitQwenMathInstruct, NitMT5encoder, NitRobertaencoder

In [3]:
from AlignmentModels import Conv1dAutoencoder, CNN1DRBencoder

In [4]:
import numpy as np
import pandas as pd

In [5]:
import torch
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split

In [6]:
from torch.nn.utils import clip_grad_norm_

In [7]:
# torch.cuda.set_device(4)

In [8]:
torch.set_float32_matmul_precision('high')

In [9]:
from tqdm import tqdm

In [10]:
from datasets import load_dataset

In [11]:
import re
def extract_num_output(text):
    """Pull the final answer out of a worked-solution string.

    Looks for the marker ``The answer is:`` followed by exactly one whitespace
    character and returns everything after it up to the end of the string.
    Because ``.`` does not cross newlines and ``$`` is not in multiline mode,
    the answer must sit on the last line of ``text``.

    Args:
        text: The solution text to scan.

    Returns:
        The text following the marker (possibly an empty string), or ``None``
        when the marker is not found in that position.
    """
    found = re.search(r'(?<=The answer is:\s).*$', text)
    return None if found is None else found.group(0)

In [12]:
# Local cache folders (relative to the notebook's working directory):
# data_dir is passed to load_dataset below, model_dir to the encoder/LLM wrappers later on.
data_dir = "data_cache"
model_dir = "model_cache"
# Math question/answer pairs; the 'train' split is converted to a DataFrame in the next cell.
math_ds = load_dataset("meta-math/MetaMathQA", cache_dir=data_dir)
# Plain English sentences; also consumed via its 'train' split.
sen_ds = load_dataset("sentence-transformers/wikipedia-en-sentences",cache_dir=data_dir)

In [13]:
math_df = math_ds['train'].to_pandas()
sen_df = sen_ds['train'].to_pandas()

In [14]:
math_df.head()

Unnamed: 0,type,query,original_question,response
0,MATH_AnsAug,Gracie and Joe are choosing numbers on the com...,Gracie and Joe are choosing numbers on the com...,"The distance between two points $(x_1,y_1)$ an..."
1,GSM_Rephrased,What is the total cost of purchasing equipment...,The treasurer of a football team must buy equi...,"Each player requires a $25 jersey, a $15.20 pa..."
2,GSM_SV,Diego baked 12 cakes for his sister's birthday...,Diego baked 12 cakes for his sister's birthday...,"To solve this problem, we need to determine th..."
3,MATH_AnsAug,Convert $10101_3$ to a base 10 integer.,Convert $10101_3$ to a base 10 integer.,$10101_3 = 1 \cdot 3^4 + 0 \cdot 3^3 + 1 \cdot...
4,GSM_FOBAR,"Sue works in a factory and every 30 minutes, a...","Sue works in a factory and every 30 minutes, a...","We know that every 30 minutes, a machine produ..."


In [15]:
sen_df.head()

Unnamed: 0,sentence
0,"The film stars M. G. Ramachandran, Latha, Anja..."
1,Naarda plenirena is a species of moth in the f...
2,Sponsored by the American Federation of Labor ...
3,Since that election the Belfast Corporation Ac...
4,It was also included on their Best of Volume 1.


In [16]:
math_df['Numerical_output']= math_df['response'].apply(extract_num_output)

In [17]:
math_df.head()

Unnamed: 0,type,query,original_question,response,Numerical_output
0,MATH_AnsAug,Gracie and Joe are choosing numbers on the com...,Gracie and Joe are choosing numbers on the com...,"The distance between two points $(x_1,y_1)$ an...",\sqrt{5}
1,GSM_Rephrased,What is the total cost of purchasing equipment...,The treasurer of a football team must buy equi...,"Each player requires a $25 jersey, a $15.20 pa...",752
2,GSM_SV,Diego baked 12 cakes for his sister's birthday...,Diego baked 12 cakes for his sister's birthday...,"To solve this problem, we need to determine th...",1
3,MATH_AnsAug,Convert $10101_3$ to a base 10 integer.,Convert $10101_3$ to a base 10 integer.,$10101_3 = 1 \cdot 3^4 + 0 \cdot 3^3 + 1 \cdot...,91
4,GSM_FOBAR,"Sue works in a factory and every 30 minutes, a...","Sue works in a factory and every 30 minutes, a...","We know that every 30 minutes, a machine produ...",1


In [18]:
math_df.isna().sum()

type                 0
query                0
original_question    0
response             0
Numerical_output     0
dtype: int64

In [19]:
sen_df.isna().sum()

sentence    0
dtype: int64

# Data Loaders

In [20]:
batch_size = 1024

In [21]:
math_train_query = math_df['query']
sen_train = sen_df['sentence'][0:500000]

In [22]:
combined_train = pd.concat([math_train_query, sen_train], ignore_index=True)

In [23]:
# Despite its name, math_array holds the COMBINED series (math queries followed by the
# sliced Wikipedia sentences) built in the previous cell, not the math queries alone.
# NOTE(review): the saved training log reports 309 batches per epoch, i.e. at most
# 309 * 1024 = 316,416 training rows, which is fewer than 80% of the 500,000-sentence
# slice on its own. The recorded run was presumably made with a different definition of
# this array (e.g. math queries only) - confirm which one is intended and re-run.
math_array = combined_train .to_numpy()
# sen_array holds only the sliced Wikipedia sentences.
sen_array = sen_train.to_numpy()

In [24]:
# 80/20 train/test splits with a fixed seed so the partitions are reproducible.
# The two splits are drawn independently. Because math_array is built from the combined
# series, the sentences in sen_array also appear in math_array, so a sentence can end up
# in the training part of one split and the test part of the other. Keep that in mind
# if both loader pairs are ever used on the same model.
math_X_train, math_X_test = train_test_split(math_array, test_size=0.2, random_state=42)
sen_X_train, sen_X_test = train_test_split(sen_array, test_size=0.2, random_state=42)

In [25]:
math_train_loader = DataLoader(math_X_train, batch_size=batch_size, shuffle=True)
math_test_loader = DataLoader(math_X_test, batch_size=batch_size, shuffle=False)

In [26]:
sen_train_loader = DataLoader(sen_X_train, batch_size=batch_size, shuffle=True)
sen_test_loader = DataLoader(sen_X_test, batch_size=batch_size, shuffle=False)

# LLM and Encoder init

In [27]:
max_tokens = 100
padding = "max_length"

In [28]:
qwen_template = "<|im_start|>{text}<|im_end|>"

In [29]:
gpu_ids = [0,1,2,3]

In [None]:
qwen = NitQwenMathInstruct(cache_dir=model_dir, max_tokens=max_tokens, padding=padding)
qwen.setTemplate(qwen_template)

In [None]:
rb = NitRobertaencoder(cache_dir=model_dir, max_tokens=max_tokens, padding=padding)

Some weights of RobertaModel were not initialized from the model checkpoint at FacebookAI/roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
rb.useDataParellel(gpu_ids)
qwen.useDataParellel(gpu_ids)

In [None]:
rb_embedding_shape = rb.getEmbedding_shape()
qwen_embedding_shape = qwen.getEmbedding_shape()

In [None]:
rb_embedding_shape, qwen_embedding_shape

((100, 768), (100, 1536))

# Alignment Model

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
align_model = CNN1DRBencoder(rb_embedding_shape, qwen_embedding_shape).to(device)

In [None]:
# Wrap the alignment model in DataParallel over the GPU ids listed in gpu_ids.
# NOTE(review): `device` is chosen with a CPU fallback when CUDA is unavailable, but this
# wrap (like the useDataParellel calls on the encoders) always passes the hard-coded
# gpu_ids - confirm the notebook is only meant to run on a host exposing those GPUs.
align_model = torch.nn.DataParallel(align_model, device_ids=gpu_ids)

# Testing NITModels with embeddings

In [None]:
test_texts = ["how many legs do 400 dogs have, if each dog has 4 legs?"]

In [None]:
qwen.applyTemplate(test_texts)

['<|im_start|>how many legs do 400 dogs have, if each dog has 4 legs?<|im_end|>']

In [None]:
qwen.get_embeddings_from_text(test_texts).input_embeds

tensor([[[-0.0175, -0.0038, -0.0056,  ...,  0.0272, -0.0130,  0.0221],
         [ 0.0146,  0.0037, -0.0320,  ...,  0.0228, -0.0055,  0.0388],
         [ 0.0096, -0.0099,  0.0216,  ...,  0.0417, -0.0024,  0.0309],
         ...,
         [ 0.0408,  0.0123,  0.0154,  ...,  0.0070,  0.0210, -0.0142],
         [ 0.0408,  0.0123,  0.0154,  ...,  0.0070,  0.0210, -0.0142],
         [ 0.0408,  0.0123,  0.0154,  ...,  0.0070,  0.0210, -0.0142]]],
       device='cuda:0', grad_fn=<GatherBackward>)

In [None]:
rb.applyFormating(test_texts)

['how many legs do  4 0 0 dogs have, if each dog has  4 legs?']

In [None]:
rb_embs = rb.get_embeddings_from_text(test_texts, hiddenLayer=True).input_embeds

In [None]:
rb_embs.shape

torch.Size([1, 100, 768])

In [None]:
rb_embs[0]

tensor([[-0.0274,  0.0641, -0.0252,  ..., -0.0356, -0.0638, -0.0433],
        [-0.0557, -0.4102, -0.0342,  ..., -0.2139,  0.2019, -0.2752],
        [ 0.1365,  0.0160,  0.2783,  ...,  0.1263, -0.0668, -0.7157],
        ...,
        [ 0.0561,  0.0581,  0.0339,  ...,  0.1656,  0.0106, -0.0732],
        [ 0.0561,  0.0581,  0.0339,  ...,  0.1656,  0.0106, -0.0732],
        [ 0.0561,  0.0581,  0.0339,  ...,  0.1656,  0.0106, -0.0732]],
       device='cuda:0', grad_fn=<SelectBackward0>)

# Training

In [None]:
# define train loop
def train(model, encoder, llm, train_loader, test_loader, optimizer, criterion, epochs=10, scaler=None, clip_value=1.0):
    """Train ``model`` to map encoder embeddings onto LLM embeddings.

    For each batch, ``encoder`` and ``llm`` both turn the raw batch into
    embeddings through ``get_embeddings_from_text(batch).input_embeds``. These
    are computed under ``torch.no_grad()``, so only ``model`` is optimised.
    ``model`` is applied to the encoder embeddings and ``criterion`` compares
    its output with the LLM embeddings. After every epoch the same computation
    is run over ``test_loader`` in eval mode and both losses are printed.

    Batches whose loss is NaN or infinite are reported and skipped. The epoch
    losses are averaged over the samples of the batches that actually
    contributed, so skipped batches (or a loader that does not yield every
    dataset item) no longer deflate the reported value, and an epoch with no
    usable batch reports ``nan`` instead of raising ``ZeroDivisionError``.

    Args:
        model: Module being trained; called as ``model(encoder_embedding)``.
        encoder: Object exposing ``get_embeddings_from_text``; source of inputs.
        llm: Object exposing ``get_embeddings_from_text``; source of targets.
        train_loader: Iterable of training batches.
        test_loader: Iterable of validation batches.
        optimizer: Optimiser holding ``model``'s parameters.
        criterion: Callable ``criterion(output, target)`` returning a scalar
            loss. The per-batch loss is weighted by the batch size, which gives
            a per-sample average only if the criterion returns a batch mean
            (assumed; true for ``MSELoss`` with its default reduction).
        epochs: Number of passes over ``train_loader``.
        scaler: Optional gradient scaler. When truthy, the backward pass and
            optimiser step go through it. NOTE(review): autocast is currently
            commented out, so the scaler only scales full-precision losses.
        clip_value: Maximum gradient norm passed to ``clip_grad_norm_``.

    Returns:
        None. Progress is reported through ``tqdm`` bars and ``print``.
    """

    def _epoch_average(weighted_sum, n_samples):
        # Mean over the samples that really contributed; NaN when none did.
        return weighted_sum / n_samples if n_samples else float("nan")

    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        train_seen = 0

        for batch in tqdm(train_loader, desc=f'epoch_{epoch+1}/{epochs}'):

            # Inputs and targets come from frozen models: no graph is needed.
            with torch.no_grad():
                encoder_embedding = encoder.get_embeddings_from_text(batch).input_embeds
                llm_embedding = llm.get_embeddings_from_text(batch).input_embeds

            # Use AMP for mixed precision if scaler is provided
            # with torch.amp.autocast("cuda", enabled=(scaler is not None)):
            output = model(encoder_embedding)
            loss = criterion(output, llm_embedding)

            # Bail out before backward so a bad batch cannot poison the weights.
            if not torch.isfinite(loss):
                print("Numerical instability detected in training. Skipping this batch.")
                continue

            optimizer.zero_grad()
            if scaler:
                scaler.scale(loss).backward()
                # Gradients must be unscaled before clipping their true norm.
                scaler.unscale_(optimizer)
                clip_grad_norm_(model.parameters(), clip_value)
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                clip_grad_norm_(model.parameters(), clip_value)
                optimizer.step()

            batch_samples = output.size(0)
            train_loss += loss.item() * batch_samples
            train_seen += batch_samples

        train_loss = _epoch_average(train_loss, train_seen)

        model.eval()
        val_loss = 0.0
        val_seen = 0
        with torch.no_grad():
            for batch in tqdm(test_loader, desc="Validation :"):
                encoder_embedding = encoder.get_embeddings_from_text(batch).input_embeds
                llm_embedding = llm.get_embeddings_from_text(batch).input_embeds

                output = model(encoder_embedding)
                loss = criterion(output, llm_embedding)

                if not torch.isfinite(loss):
                    print("Numerical instability detected in validation. Skipping this batch.")
                    continue

                batch_samples = output.size(0)
                val_loss += loss.item() * batch_samples
                val_seen += batch_samples

        val_loss = _epoch_average(val_loss, val_seen)
        print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.7f}, Val Loss: {val_loss:.7f}")

In [None]:
# Adam over the alignment model's parameters only; the encoder and LLM wrappers are not optimised.
optimizer = torch.optim.Adam(align_model.parameters(), lr=1e-3)
# Mean-squared error between the alignment model's output and the LLM embeddings (see train()).
criterion = torch.nn.MSELoss()
# Gradient scaling is disabled: the train() calls in this notebook pass scaler=None.
# scaler = torch.amp.GradScaler("cuda")

In [None]:
# train(align_model, rb, qwen, sen_train_loader, sen_test_loader, optimizer, criterion, epochs=2, scaler=None)

In [None]:
train(align_model, rb, qwen, math_train_loader, math_test_loader, optimizer, criterion, epochs=15, scaler=None)

epoch_1/15: 100%|██████████| 309/309 [01:21<00:00,  3.80it/s]
Validation :: 100%|██████████| 78/78 [00:18<00:00,  4.24it/s]


Epoch 1/15, Train Loss: 0.0003236, Val Loss: 0.0002717


epoch_2/15: 100%|██████████| 309/309 [01:18<00:00,  3.94it/s]
Validation :: 100%|██████████| 78/78 [00:18<00:00,  4.29it/s]


Epoch 2/15, Train Loss: 0.0002389, Val Loss: 0.0002130


epoch_3/15: 100%|██████████| 309/309 [01:18<00:00,  3.94it/s]
Validation :: 100%|██████████| 78/78 [00:18<00:00,  4.26it/s]


Epoch 3/15, Train Loss: 0.0001935, Val Loss: 0.0001804


epoch_4/15: 100%|██████████| 309/309 [01:18<00:00,  3.94it/s]
Validation :: 100%|██████████| 78/78 [00:18<00:00,  4.28it/s]


Epoch 4/15, Train Loss: 0.0001654, Val Loss: 0.0001579


epoch_5/15: 100%|██████████| 309/309 [01:18<00:00,  3.94it/s]
Validation :: 100%|██████████| 78/78 [00:18<00:00,  4.25it/s]


Epoch 5/15, Train Loss: 0.0001457, Val Loss: 0.0001430


epoch_6/15: 100%|██████████| 309/309 [01:18<00:00,  3.95it/s]
Validation :: 100%|██████████| 78/78 [00:18<00:00,  4.25it/s]


Epoch 6/15, Train Loss: 0.0001326, Val Loss: 0.0001335


epoch_7/15: 100%|██████████| 309/309 [01:18<00:00,  3.95it/s]
Validation :: 100%|██████████| 78/78 [00:18<00:00,  4.29it/s]


Epoch 7/15, Train Loss: 0.0001229, Val Loss: 0.0001247


epoch_8/15: 100%|██████████| 309/309 [01:18<00:00,  3.94it/s]
Validation :: 100%|██████████| 78/78 [00:18<00:00,  4.25it/s]


Epoch 8/15, Train Loss: 0.0001144, Val Loss: 0.0001183


epoch_9/15: 100%|██████████| 309/309 [01:19<00:00,  3.90it/s]
Validation :: 100%|██████████| 78/78 [00:18<00:00,  4.24it/s]


Epoch 9/15, Train Loss: 0.0001081, Val Loss: 0.0001130


epoch_10/15: 100%|██████████| 309/309 [01:19<00:00,  3.88it/s]
Validation :: 100%|██████████| 78/78 [00:18<00:00,  4.18it/s]


Epoch 10/15, Train Loss: 0.0001030, Val Loss: 0.0001101


epoch_11/15: 100%|██████████| 309/309 [01:19<00:00,  3.87it/s]
Validation :: 100%|██████████| 78/78 [00:18<00:00,  4.22it/s]


Epoch 11/15, Train Loss: 0.0000985, Val Loss: 0.0001052


epoch_12/15: 100%|██████████| 309/309 [01:19<00:00,  3.88it/s]
Validation :: 100%|██████████| 78/78 [00:18<00:00,  4.20it/s]


Epoch 12/15, Train Loss: 0.0000945, Val Loss: 0.0001016


epoch_13/15: 100%|██████████| 309/309 [01:19<00:00,  3.89it/s]
Validation :: 100%|██████████| 78/78 [00:18<00:00,  4.23it/s]


Epoch 13/15, Train Loss: 0.0000914, Val Loss: 0.0000994


epoch_14/15: 100%|██████████| 309/309 [01:19<00:00,  3.88it/s]
Validation :: 100%|██████████| 78/78 [00:18<00:00,  4.20it/s]


Epoch 14/15, Train Loss: 0.0000972, Val Loss: 0.0306302


epoch_15/15: 100%|██████████| 309/309 [01:19<00:00,  3.88it/s]
Validation :: 100%|██████████| 78/78 [00:18<00:00,  4.20it/s]

Epoch 15/15, Train Loss: 0.0015182, Val Loss: 0.0004271



