In [13]:
from NitMultilingualEncoders import NitQwenMathInstruct, NitMT5encoder, NitRobertaencoder

In [None]:
from AlignmentModels import Conv1dAutoencoder, CNN1DRBencoder, CNN1DRBdecoder

In [15]:
import numpy as np
import pandas as pd

In [16]:
import torch
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split

In [17]:
from torch.nn.utils import clip_grad_norm_

In [18]:
# torch.cuda.set_device(4)

In [19]:
torch.set_float32_matmul_precision('high')

In [20]:
from tqdm import tqdm

In [21]:
from datasets import load_dataset

In [22]:
import re
def extract_num_output(text):
    """Return the final answer that follows "The answer is: " in a response.

    The pattern is anchored with ``$`` and ``.`` does not match newlines, so
    the marker is only recognised when the answer sits on the last line of
    ``text`` (a single trailing newline is tolerated and excluded).

    Args:
        text: Response string to search. Non-string values (``None``, or the
            float NaN pandas uses for missing cells) are accepted.

    Returns:
        The text after the marker up to the end of the string, or ``None``
        when ``text`` is not a string or contains no such marker.
    """
    # re.search raises TypeError on non-strings; a missing response should
    # yield "no answer" instead of aborting a whole DataFrame.apply call.
    if not isinstance(text, str):
        return None
    match = re.search(r'(?<=The answer is:\s).*$', text)
    return match.group(0) if match else None

In [23]:
# Local cache locations: `data_dir` is handed to `load_dataset` below, and
# `model_dir` is handed to the model wrappers as `cache_dir` further down.
data_dir = "data_cache"
model_dir = "model_cache"
# Two text sources: MetaMathQA (math questions and responses) and a corpus of
# English Wikipedia sentences. Both are cached under `data_dir`.
math_ds = load_dataset("meta-math/MetaMathQA", cache_dir=data_dir)
sen_ds = load_dataset("sentence-transformers/wikipedia-en-sentences",cache_dir=data_dir)

In [24]:
math_df = math_ds['train'].to_pandas()
sen_df = sen_ds['train'].to_pandas()

In [25]:
math_df.head()

Unnamed: 0,type,query,original_question,response
0,MATH_AnsAug,Gracie and Joe are choosing numbers on the com...,Gracie and Joe are choosing numbers on the com...,"The distance between two points $(x_1,y_1)$ an..."
1,GSM_Rephrased,What is the total cost of purchasing equipment...,The treasurer of a football team must buy equi...,"Each player requires a $25 jersey, a $15.20 pa..."
2,GSM_SV,Diego baked 12 cakes for his sister's birthday...,Diego baked 12 cakes for his sister's birthday...,"To solve this problem, we need to determine th..."
3,MATH_AnsAug,Convert $10101_3$ to a base 10 integer.,Convert $10101_3$ to a base 10 integer.,$10101_3 = 1 \cdot 3^4 + 0 \cdot 3^3 + 1 \cdot...
4,GSM_FOBAR,"Sue works in a factory and every 30 minutes, a...","Sue works in a factory and every 30 minutes, a...","We know that every 30 minutes, a machine produ..."


In [26]:
sen_df.head()

Unnamed: 0,sentence
0,"The film stars M. G. Ramachandran, Latha, Anja..."
1,Naarda plenirena is a species of moth in the f...
2,Sponsored by the American Federation of Labor ...
3,Since that election the Belfast Corporation Ac...
4,It was also included on their Best of Volume 1.


In [27]:
math_df['Numerical_output']= math_df['response'].apply(extract_num_output)

In [28]:
math_df.head()

Unnamed: 0,type,query,original_question,response,Numerical_output
0,MATH_AnsAug,Gracie and Joe are choosing numbers on the com...,Gracie and Joe are choosing numbers on the com...,"The distance between two points $(x_1,y_1)$ an...",\sqrt{5}
1,GSM_Rephrased,What is the total cost of purchasing equipment...,The treasurer of a football team must buy equi...,"Each player requires a $25 jersey, a $15.20 pa...",752
2,GSM_SV,Diego baked 12 cakes for his sister's birthday...,Diego baked 12 cakes for his sister's birthday...,"To solve this problem, we need to determine th...",1
3,MATH_AnsAug,Convert $10101_3$ to a base 10 integer.,Convert $10101_3$ to a base 10 integer.,$10101_3 = 1 \cdot 3^4 + 0 \cdot 3^3 + 1 \cdot...,91
4,GSM_FOBAR,"Sue works in a factory and every 30 minutes, a...","Sue works in a factory and every 30 minutes, a...","We know that every 30 minutes, a machine produ...",1


In [29]:
math_df.isna().sum()

type                 0
query                0
original_question    0
response             0
Numerical_output     0
dtype: int64

In [30]:
sen_df.isna().sum()

sentence    0
dtype: int64

# Data Loaders

In [31]:
batch_size = 1024

In [None]:
# Only the `query` column of the math data is used as input text.
math_train_query = math_df['query']
# Down-sample the sentence corpus to 500,000 rows; the fixed seed keeps the
# chosen subset identical across runs.
sen_train = sen_df['sentence'].sample(n=500000, random_state=42)  

In [33]:
math_array = math_train_query.to_numpy()
sen_array = sen_train.to_numpy()

In [34]:
# Hold out 20% of each corpus for evaluation. The two corpora are split
# independently, each with the same fixed seed for reproducibility.
math_X_train, math_X_test = train_test_split(math_array, test_size=0.2, random_state=42)
sen_X_train, sen_X_test = train_test_split(sen_array, test_size=0.2, random_state=42)

In [35]:
# Stack the math and sentence training texts into one array (math first).
# Only the training portions are merged; the test splits stay separate.
combined_X_train = np.concatenate((math_X_train, sen_X_train), axis=0)

In [36]:
math_train_loader = DataLoader(math_X_train, batch_size=batch_size, shuffle=True)
math_test_loader = DataLoader(math_X_test, batch_size=batch_size, shuffle=False)

In [37]:
sen_train_loader = DataLoader(sen_X_train, batch_size=batch_size, shuffle=True)
sen_test_loader = DataLoader(sen_X_test, batch_size=batch_size, shuffle=False)

In [38]:
# Shuffled loader over the merged array, so each batch can contain both math
# queries and Wikipedia sentences.
combined_train_loader = DataLoader(combined_X_train, batch_size=batch_size, shuffle=True)

# LLM and Encoder init

In [39]:
# Tokenisation settings shared by both model wrappers below.
# NOTE(review): presumably every text is padded/truncated to `max_tokens`
# tokens; the embedding shapes printed later have a leading dimension of 100,
# which is consistent with that. Confirm in NitMultilingualEncoders.
max_tokens = 100
padding = "max_length"

In [40]:
qwen_template = "<|im_start|>{text}<|im_end|>"

In [41]:
gpu_ids = [0,1,2,3]

In [42]:
# Target side of the alignment: the Qwen wrapper, using the same cache
# directory and tokenisation settings as the encoder.
qwen = NitQwenMathInstruct(cache_dir=model_dir, max_tokens=max_tokens, padding=padding)
# Register the chat-style template defined above. `applyTemplate` (shown in
# the testing section) wraps each text as <|im_start|>...<|im_end|>.
qwen.setTemplate(qwen_template)

In [None]:
rb = NitRobertaencoder(cache_dir=model_dir, max_tokens=max_tokens, padding=padding)



In [None]:
# Spread both embedding providers over the GPUs listed in `gpu_ids`.
# NOTE(review): presumably this wraps the underlying models in
# torch.nn.DataParallel - confirm in NitMultilingualEncoders.
rb.useDataParellel(gpu_ids)
qwen.useDataParellel(gpu_ids)

In [None]:
rb_embedding_shape = rb.getEmbedding_shape()
qwen_embedding_shape = qwen.getEmbedding_shape()

In [None]:
rb_embedding_shape, qwen_embedding_shape

((100, 768), (100, 1536))

# Alignment Model

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Alignment network built from the two embedding shapes reported above:
# (100, 768) for the RoBERTa wrapper and (100, 1536) for the Qwen wrapper.
# NOTE(review): argument order is assumed to be (input_shape, output_shape);
# verify against AlignmentModels.CNN1DRBdecoder.
align_model = CNN1DRBdecoder(rb_embedding_shape, qwen_embedding_shape).to(device)

In [None]:
# Replicate the alignment model across `gpu_ids`. From here on `align_model`
# is the DataParallel wrapper, so its state_dict keys carry a "module." prefix.
align_model = torch.nn.DataParallel(align_model, device_ids=gpu_ids)

# Testing NITModels with embeddings

In [None]:
test_texts = ["how many legs do 400 dogs have, if each dog has 4 legs?"]

In [None]:
qwen.applyTemplate(test_texts)

['<|im_start|>how many legs do 400 dogs have, if each dog has 4 legs?<|im_end|>']

In [None]:
qwen.get_embeddings_from_text(test_texts).input_embeds

tensor([[[-0.0175, -0.0038, -0.0056,  ...,  0.0272, -0.0130,  0.0221],
         [ 0.0146,  0.0037, -0.0320,  ...,  0.0228, -0.0055,  0.0388],
         [ 0.0096, -0.0099,  0.0216,  ...,  0.0417, -0.0024,  0.0309],
         ...,
         [ 0.0408,  0.0123,  0.0154,  ...,  0.0070,  0.0210, -0.0142],
         [ 0.0408,  0.0123,  0.0154,  ...,  0.0070,  0.0210, -0.0142],
         [ 0.0408,  0.0123,  0.0154,  ...,  0.0070,  0.0210, -0.0142]]],
       device='cuda:0', grad_fn=<GatherBackward>)

In [None]:
rb.applyFormating(test_texts)

['how many legs do  4 0 0 dogs have, if each dog has  4 legs?']

In [None]:
rb_embs = rb.get_embeddings_from_text(test_texts, pooler_output=True).input_embeds

In [None]:
rb_embs.shape

torch.Size([1, 100, 768])

In [None]:
rb_embs[0]

tensor([[-0.0274,  0.0641, -0.0252,  ..., -0.0356, -0.0638, -0.0433],
        [-0.0557, -0.4102, -0.0342,  ..., -0.2139,  0.2019, -0.2752],
        [ 0.1365,  0.0160,  0.2783,  ...,  0.1263, -0.0668, -0.7157],
        ...,
        [ 0.0561,  0.0581,  0.0339,  ...,  0.1656,  0.0106, -0.0732],
        [ 0.0561,  0.0581,  0.0339,  ...,  0.1656,  0.0106, -0.0732],
        [ 0.0561,  0.0581,  0.0339,  ...,  0.1656,  0.0106, -0.0732]],
       device='cuda:0', grad_fn=<SelectBackward0>)

# Training

In [None]:
# define train loop
def train(model, encoder, llm, train_loader, test_loader, optimizer, criterion, epochs=10, scaler=None, clip_value=1.0):
    """Fit ``model`` to map encoder embeddings onto LLM embeddings.

    For every batch of raw inputs, ``encoder`` and ``llm`` each produce an
    ``input_embeds`` tensor (without gradient tracking); ``model`` is trained
    so that ``model(encoder_embeds)`` matches ``llm_embeds`` under
    ``criterion``. After each epoch a validation pass is run and one summary
    line is printed.

    Batches whose loss is NaN or Inf are skipped. The reported epoch losses
    are averaged over the samples that actually contributed, so skipped
    batches do not drag the average towards zero. If no batch contributed,
    the loss is reported as ``nan``.

    Args:
        model: Trainable module called as ``model(encoder_embeds)``.
        encoder: Object exposing ``get_embeddings_from_text(batch)`` whose
            result has an ``input_embeds`` attribute (model input).
        llm: Same interface as ``encoder``; supplies the regression target.
        train_loader: Iterable of training batches.
        test_loader: Iterable of validation batches.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss called as ``criterion(output, target)``.
            NOTE(review): the per-sample averaging assumes a mean-reduced
            loss (e.g. the default ``MSELoss``) - confirm if changed.
        epochs: Number of passes over ``train_loader``.
        scaler: Optional gradient scaler. When given, the loss is scaled
            before ``backward`` and gradients are unscaled before clipping.
            Autocast itself is not enabled by this function.
        clip_value: Max gradient norm passed to ``clip_grad_norm_``.

    Returns:
        None. Progress is reported via ``tqdm`` bars and ``print``.
    """
    for epoch in range(epochs):
        # ---------------- training pass ----------------
        model.train()
        train_loss = 0.0
        train_seen = 0  # samples that actually contributed to train_loss

        for batch in tqdm(train_loader, desc=f'epoch_{epoch+1}/{epochs}'):
            # Inputs and targets are computed without gradient tracking, so
            # only `model` receives gradients.
            with torch.no_grad():
                encoder_embedding = encoder.get_embeddings_from_text(batch).input_embeds
                llm_embedding = llm.get_embeddings_from_text(batch).input_embeds

            # Use AMP for mixed precision if scaler is provided
            # with torch.amp.autocast("cuda", enabled=(scaler is not None)):
            output = model(encoder_embedding)
            loss = criterion(output, llm_embedding)

            if torch.isnan(loss) or torch.isinf(loss):
                print("Numerical instability detected in training. Skipping this batch.")
                continue

            optimizer.zero_grad()
            if scaler is not None:
                scaler.scale(loss).backward()
                # Gradients must be unscaled before clipping so the threshold
                # applies to their true magnitude.
                scaler.unscale_(optimizer)
                clip_grad_norm_(model.parameters(), clip_value)
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                clip_grad_norm_(model.parameters(), clip_value)
                optimizer.step()

            batch_len = output.size(0)
            train_loss += loss.item() * batch_len
            train_seen += batch_len

        train_loss = train_loss / train_seen if train_seen else float("nan")

        # ---------------- validation pass ----------------
        model.eval()
        val_loss = 0.0
        val_seen = 0  # samples that actually contributed to val_loss
        with torch.no_grad():
            for batch in tqdm(test_loader, desc="Validation :"):
                encoder_embedding = encoder.get_embeddings_from_text(batch).input_embeds
                llm_embedding = llm.get_embeddings_from_text(batch).input_embeds

                output = model(encoder_embedding)
                loss = criterion(output, llm_embedding)

                if torch.isnan(loss) or torch.isinf(loss):
                    print("Numerical instability detected in validation. Skipping this batch.")
                    continue

                batch_len = output.size(0)
                val_loss += loss.item() * batch_len
                val_seen += batch_len

        val_loss = val_loss / val_seen if val_seen else float("nan")
        print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.7f}, Val Loss: {val_loss:.7f}")

In [None]:
# Adam over the alignment model's parameters only; the RoBERTa and Qwen
# wrappers are not handed to the optimizer.
optimizer = torch.optim.Adam(align_model.parameters(), lr=1e-3)
# Element-wise mean squared error between predicted and target embeddings.
criterion = torch.nn.MSELoss()
# Gradient scaling is left disabled; `train` is called with scaler=None below.
# scaler = torch.amp.GradScaler("cuda")

In [None]:
# train(align_model, rb, qwen, sen_train_loader, sen_test_loader, optimizer, criterion, epochs=2, scaler=None)

In [None]:
# Train on the merged math + Wikipedia batches, but validate on the math test
# split only, so the reported Val Loss reflects math queries exclusively.
train(align_model, rb, qwen, combined_train_loader, math_test_loader, optimizer, criterion, epochs=15, scaler=None)

epoch_1/15: 100%|██████████| 700/700 [02:51<00:00,  4.08it/s]
Validation :: 100%|██████████| 78/78 [00:18<00:00,  4.26it/s]


Epoch 1/15, Train Loss: 0.0002061, Val Loss: 0.0002400


epoch_2/15: 100%|██████████| 700/700 [02:49<00:00,  4.13it/s]
Validation :: 100%|██████████| 78/78 [00:18<00:00,  4.23it/s]


Epoch 2/15, Train Loss: 0.0001547, Val Loss: 0.0001919


epoch_3/15: 100%|██████████| 700/700 [02:49<00:00,  4.13it/s]
Validation :: 100%|██████████| 78/78 [00:18<00:00,  4.26it/s]


Epoch 3/15, Train Loss: 0.0001337, Val Loss: 0.0001632


epoch_4/15: 100%|██████████| 700/700 [02:49<00:00,  4.13it/s]
Validation :: 100%|██████████| 78/78 [00:18<00:00,  4.26it/s]


Epoch 4/15, Train Loss: 0.0001202, Val Loss: 0.0001458


epoch_5/15: 100%|██████████| 700/700 [02:49<00:00,  4.13it/s]
Validation :: 100%|██████████| 78/78 [00:18<00:00,  4.21it/s]


Epoch 5/15, Train Loss: 0.0001108, Val Loss: 0.0001330


epoch_6/15: 100%|██████████| 700/700 [02:49<00:00,  4.12it/s]
Validation :: 100%|██████████| 78/78 [00:18<00:00,  4.25it/s]


Epoch 6/15, Train Loss: 0.0001044, Val Loss: 0.0001225


epoch_7/15: 100%|██████████| 700/700 [02:49<00:00,  4.13it/s]
Validation :: 100%|██████████| 78/78 [00:18<00:00,  4.25it/s]


Epoch 7/15, Train Loss: 0.0000994, Val Loss: 0.0001162


epoch_8/15: 100%|██████████| 700/700 [02:49<00:00,  4.13it/s]
Validation :: 100%|██████████| 78/78 [00:18<00:00,  4.21it/s]


Epoch 8/15, Train Loss: 0.0000952, Val Loss: 0.0001112


epoch_9/15: 100%|██████████| 700/700 [02:49<00:00,  4.13it/s]
Validation :: 100%|██████████| 78/78 [00:18<00:00,  4.23it/s]


Epoch 9/15, Train Loss: 0.0000916, Val Loss: 0.0001059


epoch_10/15: 100%|██████████| 700/700 [02:49<00:00,  4.13it/s]
Validation :: 100%|██████████| 78/78 [00:18<00:00,  4.24it/s]


Epoch 10/15, Train Loss: 0.0000884, Val Loss: 0.0001033


epoch_11/15: 100%|██████████| 700/700 [02:49<00:00,  4.12it/s]
Validation :: 100%|██████████| 78/78 [00:18<00:00,  4.22it/s]


Epoch 11/15, Train Loss: 0.0000856, Val Loss: 0.0000996


epoch_12/15: 100%|██████████| 700/700 [02:49<00:00,  4.13it/s]
Validation :: 100%|██████████| 78/78 [00:18<00:00,  4.25it/s]


Epoch 12/15, Train Loss: 0.0000833, Val Loss: 0.0000970


epoch_13/15: 100%|██████████| 700/700 [02:50<00:00,  4.11it/s]
Validation :: 100%|██████████| 78/78 [00:18<00:00,  4.23it/s]


Epoch 13/15, Train Loss: 0.0000812, Val Loss: 0.0000949


epoch_14/15: 100%|██████████| 700/700 [02:49<00:00,  4.13it/s]
Validation :: 100%|██████████| 78/78 [00:18<00:00,  4.22it/s]


Epoch 14/15, Train Loss: 0.0000792, Val Loss: 0.0000930


epoch_15/15: 100%|██████████| 700/700 [02:49<00:00,  4.13it/s]
Validation :: 100%|██████████| 78/78 [00:18<00:00,  4.26it/s]

Epoch 15/15, Train Loss: 0.0000775, Val Loss: 0.0000911





In [None]:
def save_model(model, filename='model.pth', unwrap_parallel=False):
    """Write a model's ``state_dict`` to ``filename`` with ``torch.save``.

    Args:
        model: Module whose parameters are saved.
        filename: Destination path of the checkpoint.
        unwrap_parallel: When True and ``model`` is a ``DataParallel`` /
            ``DistributedDataParallel`` wrapper, save the wrapped
            ``model.module`` instead. The wrapper's own state dict prefixes
            every key with ``"module."``, which prevents loading the
            checkpoint into a plain, unwrapped model. Defaults to False so
            existing checkpoints and loaders keep their key layout.

    Returns:
        None. A confirmation line is printed after saving.
    """
    parallel_wrappers = (
        torch.nn.DataParallel,
        torch.nn.parallel.DistributedDataParallel,
    )
    target = model
    if unwrap_parallel and isinstance(model, parallel_wrappers):
        target = model.module
    torch.save(target.state_dict(), filename)
    print(f"Model saved to {filename}")

In [None]:
save_model(align_model, filename='qwen_rb.pth')

Model saved to qwen_rb.pth


: 