In [1]:
from NitMultilingualEncoders import NitQwenMathInstruct, NitMT5encoder, NitRobertaencoder

In [2]:
from AlignmentModels import Conv1dAutoencoder, CNN1DRBencoder, CNN1DRBdecoder, LSTMDecoder, TransformerDecoderModel

In [3]:
import numpy as np
import pandas as pd

In [4]:
import torch
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split

In [5]:
from torch.nn.utils import clip_grad_norm_

In [6]:
# torch.cuda.set_device(4)

In [7]:
torch.set_float32_matmul_precision('high')

In [8]:
from tqdm import tqdm

In [9]:
from datasets import load_dataset

In [10]:
import re
def extract_num_output(text):
    """Return the text that follows the ``The answer is:`` marker.

    The single whitespace character right after the colon is skipped and the
    rest of that line is returned. The line has to run to the end of ``text``
    (one trailing newline is tolerated). Returns ``None`` when the marker is
    not found in that position.
    """
    found = re.search(r'(?<=The answer is:\s).*$', text)
    return None if found is None else found.group(0)

In [11]:
# Cache folders (relative paths): data_dir is handed to load_dataset here,
# model_dir to the encoder/LLM wrappers further down.
data_dir = "data_cache"
model_dir = "model_cache"
# Math questions with worked responses, and plain English Wikipedia sentences.
math_ds = load_dataset("meta-math/MetaMathQA", cache_dir=data_dir)
sen_ds = load_dataset("sentence-transformers/wikipedia-en-sentences",cache_dir=data_dir)

In [12]:
# Convert the 'train' split of each dataset to a pandas DataFrame.
math_df = math_ds['train'].to_pandas()
sen_df = sen_ds['train'].to_pandas()

In [13]:
math_df.head()

Unnamed: 0,type,query,original_question,response
0,MATH_AnsAug,Gracie and Joe are choosing numbers on the com...,Gracie and Joe are choosing numbers on the com...,"The distance between two points $(x_1,y_1)$ an..."
1,GSM_Rephrased,What is the total cost of purchasing equipment...,The treasurer of a football team must buy equi...,"Each player requires a $25 jersey, a $15.20 pa..."
2,GSM_SV,Diego baked 12 cakes for his sister's birthday...,Diego baked 12 cakes for his sister's birthday...,"To solve this problem, we need to determine th..."
3,MATH_AnsAug,Convert $10101_3$ to a base 10 integer.,Convert $10101_3$ to a base 10 integer.,$10101_3 = 1 \cdot 3^4 + 0 \cdot 3^3 + 1 \cdot...
4,GSM_FOBAR,"Sue works in a factory and every 30 minutes, a...","Sue works in a factory and every 30 minutes, a...","We know that every 30 minutes, a machine produ..."


In [14]:
sen_df.head()

Unnamed: 0,sentence
0,"The film stars M. G. Ramachandran, Latha, Anja..."
1,Naarda plenirena is a species of moth in the f...
2,Sponsored by the American Federation of Labor ...
3,Since that election the Belfast Corporation Ac...
4,It was also included on their Best of Volume 1.


In [15]:
# New column holding whatever follows "The answer is:" in each response.
math_df['Numerical_output']= math_df['response'].apply(extract_num_output)

In [16]:
math_df.head()

Unnamed: 0,type,query,original_question,response,Numerical_output
0,MATH_AnsAug,Gracie and Joe are choosing numbers on the com...,Gracie and Joe are choosing numbers on the com...,"The distance between two points $(x_1,y_1)$ an...",\sqrt{5}
1,GSM_Rephrased,What is the total cost of purchasing equipment...,The treasurer of a football team must buy equi...,"Each player requires a $25 jersey, a $15.20 pa...",752
2,GSM_SV,Diego baked 12 cakes for his sister's birthday...,Diego baked 12 cakes for his sister's birthday...,"To solve this problem, we need to determine th...",1
3,MATH_AnsAug,Convert $10101_3$ to a base 10 integer.,Convert $10101_3$ to a base 10 integer.,$10101_3 = 1 \cdot 3^4 + 0 \cdot 3^3 + 1 \cdot...,91
4,GSM_FOBAR,"Sue works in a factory and every 30 minutes, a...","Sue works in a factory and every 30 minutes, a...","We know that every 30 minutes, a machine produ...",1


In [17]:
math_df.isna().sum()

type                 0
query                0
original_question    0
response             0
Numerical_output     0
dtype: int64

In [18]:
sen_df.isna().sum()

sentence    0
dtype: int64

# Data Loaders

In [19]:
# Number of text samples per batch; shared by every DataLoader built below.
batch_size = 1024

In [20]:
# Only the raw text columns are kept: the math question and the wiki sentence.
math_train_query = math_df['query']
sen_train = sen_df['sentence']

In [21]:
math_array = math_train_query.to_numpy()
sen_array = sen_train.to_numpy()

In [22]:
# 80/20 train/test split of each corpus; the fixed random_state makes the split repeatable.
math_X_train, math_X_test = train_test_split(math_array, test_size=0.2, random_state=42)
sen_X_train, sen_X_test = train_test_split(sen_array, test_size=0.2, random_state=42)

In [23]:
# Math and sentence training texts stacked end to end into a single array.
combined_X_train = np.concatenate((math_X_train, sen_X_train), axis=0)

In [24]:
# Loaders over the math questions: training batches are shuffled, test batches keep their order.
math_train_loader = DataLoader(math_X_train, batch_size=batch_size, shuffle=True)
math_test_loader = DataLoader(math_X_test, batch_size=batch_size, shuffle=False)

In [25]:
# Loaders over the Wikipedia sentences: training batches are shuffled, test batches keep their order.
sen_train_loader = DataLoader(sen_X_train, batch_size=batch_size, shuffle=True)
sen_test_loader = DataLoader(sen_X_test, batch_size=batch_size, shuffle=False)

In [26]:
# Shuffled loader over the combined math + sentence training texts.
# NOTE(review): not passed to train() in any cell visible in this notebook — confirm it is still needed.
combined_train_loader = DataLoader(combined_X_train, batch_size=batch_size, shuffle=True)

# LLM and Encoder init

In [27]:
# Tokenisation settings passed to both the Qwen and RoBERTa wrappers below.
# presumably every text is padded/truncated to max_tokens tokens — confirm in NitMultilingualEncoders
max_tokens = 100
padding = "max_length"

In [28]:
# Template wrapped around each input text for Qwen; {text} is replaced by the raw input.
qwen_template = "<|im_start|>{text}<|im_end|>"

In [29]:
# gpu_ids = [0,1,2,3]

In [30]:
# Target LLM wrapper; its input embeddings are what the alignment model is trained to reproduce.
qwen = NitQwenMathInstruct(cache_dir=model_dir, max_tokens=max_tokens, padding=padding)
qwen.setTemplate(qwen_template)

In [31]:
# Source encoder wrapper (RoBERTa); its output is the input of the alignment model.
rb = NitRobertaencoder(cache_dir=model_dir, max_tokens=max_tokens, padding=padding)

Some weights of RobertaModel were not initialized from the model checkpoint at FacebookAI/roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [32]:
# rb.useDataParellel(gpu_ids)
# qwen.useDataParellel(gpu_ids)

In [33]:
# Embedding shapes reported by each wrapper; used to size the alignment model.
rb_embedding_shape = rb.getEmbedding_shape()
qwen_embedding_shape = qwen.getEmbedding_shape()

In [34]:
rb_embedding_shape, qwen_embedding_shape

((100, 768), (100, 1536))

# Alignment Model

In [35]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [36]:
# NOTE(review): the next cell assigns a CNN1DRBdecoder to the same name, so this
# instance is replaced before any training happens — confirm which architecture is intended.
align_model = CNN1DRBencoder(rb_embedding_shape, qwen_embedding_shape).to(device)
# align_model = TransformerDecoderModel().to(device)

In [37]:
# align_model = torch.nn.DataParallel(align_model, device_ids=gpu_ids)
# This is the model handed to the optimizer and to train() further down; it
# overrides whatever was bound to align_model by the previous cell.
align_model = CNN1DRBdecoder(rb_embedding_shape, qwen_embedding_shape).to(device)

# Testing NITModels with embeddings

In [38]:
test_texts = ["how many legs do 400 dogs have, if each dog has 4 legs?"]

In [39]:
qwen.applyTemplate(test_texts)

['<|im_start|>how many legs do 400 dogs have, if each dog has 4 legs?<|im_end|>']

In [40]:
qwen.get_embeddings_from_text(test_texts).input_embeds

tensor([[[-0.0175, -0.0038, -0.0056,  ...,  0.0272, -0.0130,  0.0221],
         [ 0.0146,  0.0037, -0.0320,  ...,  0.0228, -0.0055,  0.0388],
         [ 0.0096, -0.0099,  0.0216,  ...,  0.0417, -0.0024,  0.0309],
         ...,
         [ 0.0408,  0.0123,  0.0154,  ...,  0.0070,  0.0210, -0.0142],
         [ 0.0408,  0.0123,  0.0154,  ...,  0.0070,  0.0210, -0.0142],
         [ 0.0408,  0.0123,  0.0154,  ...,  0.0070,  0.0210, -0.0142]]],
       device='cuda:0', grad_fn=<EmbeddingBackward0>)

In [41]:
rb.applyFormating(test_texts)

['how many legs do  4 0 0 dogs have, if each dog has  4 legs?']

In [42]:
# With pooler_output=True the encoder returns one vector per text
# (shape [1, 768] for the single test sentence, as shown in the next cell's output).
rb_embs = rb.get_embeddings_from_text(test_texts, pooler_output=True).input_embeds

In [43]:
rb_embs.shape

torch.Size([1, 768])

In [44]:
rb_embs[0]

tensor([-3.0541e-01, -1.5441e-02,  3.4156e-02, -3.2357e-01, -3.1736e-01,
        -1.1502e-02,  1.7016e-02, -1.1527e-01,  1.2176e-01, -5.1969e-02,
         1.3977e-01, -3.7667e-02, -1.2690e-01, -5.4021e-02, -2.3008e-01,
        -6.3373e-02,  2.5877e-02,  7.4595e-02,  2.9210e-01, -7.5359e-02,
         8.3690e-04, -4.0202e-03, -1.4141e-02,  4.2198e-01, -6.4608e-02,
        -2.4139e-02,  3.6583e-01,  2.3151e-01,  1.1857e-01,  1.0236e-01,
        -1.8374e-02, -8.6124e-02,  2.0123e-01,  2.6544e-01, -1.5505e-01,
        -3.7288e-02,  1.1489e-03,  1.0705e-01, -7.4041e-02, -1.9716e-01,
         3.7293e-02, -2.4643e-01, -1.6754e-01,  3.3686e-01,  1.1518e-01,
        -1.9700e-01,  6.7309e-02, -3.0770e-01, -4.5063e-01,  5.0742e-01,
        -2.3951e-01,  1.7839e-01, -7.4241e-02,  4.1142e-01,  1.0115e-01,
         4.3448e-01, -3.2894e-02,  2.6248e-01, -4.9366e-02, -6.3946e-02,
        -3.8566e-02,  1.1402e-01, -6.0186e-02, -3.9147e-01, -1.2564e-01,
         1.1212e-01,  1.2654e-01, -1.9248e-01,  2.0

# Training

In [45]:
def save_checkpoint(epoch, model, optimizer, loss, path):
    """Write a training checkpoint to ``checkpoints/<path>``.

    Args:
        epoch: Epoch index stored under the ``'epoch'`` key.
        model: Module whose ``state_dict()`` is stored under ``'model_state_dict'``.
        optimizer: Optimizer whose ``state_dict()`` is stored under
            ``'optimizer_state_dict'``.
        loss: Loss value stored under the ``'loss'`` key.
        path: File name (may contain sub-folders) relative to ``checkpoints/``.
    """
    import os  # local import so this cell does not depend on the import cells

    target = "checkpoints/" + path
    # torch.save does not create missing parent folders, so on a fresh checkout
    # the very first checkpoint would fail without this.
    os.makedirs(os.path.dirname(target), exist_ok=True)
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }, target)

In [None]:
# define train loop
def train(model, encoder, llm, train_loader, test_loader, optimizer, criterion, epochs=10, scaler=None, clip_value=1.0):
    """Train ``model`` to map encoder embeddings onto the LLM's input embeddings.

    For every batch of texts the encoder and the LLM each produce embeddings
    (without gradients); ``model`` is fed the encoder embeddings and
    ``criterion`` compares its output with the LLM embeddings. After each
    epoch the model is evaluated on ``test_loader``, and a checkpoint is
    written through ``save_checkpoint`` whenever the validation loss improves
    on the best value seen so far in this call. One summary line per epoch is
    printed and appended to ``logs.txt``.

    Args:
        model: Trainable module, called as ``model(encoder_embedding)``.
        encoder: Object whose ``get_embeddings_from_text(batch, pooler_output=True)``
            returns something with an ``input_embeds`` attribute.
        llm: Object whose ``get_embeddings_from_text(batch)`` returns something
            with an ``input_embeds`` attribute.
        train_loader: Iterable of training batches.
        test_loader: Iterable of validation batches.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss callable ``criterion(output, target)``.
        epochs: Number of passes over ``train_loader``.
        scaler: Optional gradient scaler; when given, backward/step go through it.
        clip_value: Max gradient norm passed to ``clip_grad_norm_``.

    Returns:
        None.
    """
    # Must live outside the epoch loop: resetting it every epoch would make
    # every epoch look like the best one and overwrite best_model.pth each time.
    best_val_loss = np.inf

    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        train_seen = 0  # samples that actually contributed to train_loss

        for batch in tqdm(train_loader, desc = f'epoch_{epoch+1}/{epochs}'):

            # Both embedding providers are frozen; only `model` receives gradients.
            with torch.no_grad():
                encoder_embedding = encoder.get_embeddings_from_text(batch, pooler_output=True).input_embeds
                llm_embedding = llm.get_embeddings_from_text(batch).input_embeds

            # Use AMP for mixed precision if scaler is provided
            # with torch.amp.autocast("cuda", enabled=(scaler is not None)):
            output = model(encoder_embedding)
            loss = criterion(output, llm_embedding)

            if torch.isnan(loss) or torch.isinf(loss):
                print("Numerical instability detected in training. Skipping this batch.")
                continue

            optimizer.zero_grad()
            if scaler:
                scaler.scale(loss).backward()
                # Gradients must be unscaled before clipping so the threshold
                # applies to their true magnitude.
                scaler.unscale_(optimizer)
                clip_grad_norm_(model.parameters(), clip_value)
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                clip_grad_norm_(model.parameters(), clip_value)
                optimizer.step()

            batch_len = output.size(0)
            train_loss += loss.item() * batch_len
            train_seen += batch_len

        # Average over the samples that were really used; skipped batches would
        # otherwise deflate the reported loss.
        train_loss = train_loss / train_seen if train_seen else float("nan")

        model.eval()
        val_loss = 0.0
        val_seen = 0
        with torch.no_grad():
            for batch in tqdm(test_loader, desc="Validation :"):
                encoder_embedding = encoder.get_embeddings_from_text(batch, pooler_output=True).input_embeds
                llm_embedding = llm.get_embeddings_from_text(batch).input_embeds

                output = model(encoder_embedding)
                loss = criterion(output, llm_embedding)

                if torch.isnan(loss) or torch.isinf(loss):
                    print("Numerical instability detected in validation. Skipping this batch.")
                    continue

                batch_len = output.size(0)
                val_loss += loss.item() * batch_len
                val_seen += batch_len

        # NaN when nothing could be evaluated: the comparison below is then
        # False, so a broken epoch never becomes the "best" checkpoint.
        val_loss = val_loss / val_seen if val_seen else float("nan")
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            save_checkpoint(epoch, model, optimizer, best_val_loss, "best_model.pth")

        summary = f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.7f}, Val Loss: {val_loss:.7f}"
        print(summary)
        with open("logs.txt", "a") as log_file:
            log_file.write(summary + "\n")

In [47]:
# Adam over the alignment model's parameters; mean-squared error between the
# model output and the LLM embeddings is the training objective.
optimizer = torch.optim.Adam(align_model.parameters(), lr=1e-3)
criterion = torch.nn.MSELoss()
# scaler = torch.amp.GradScaler("cuda")

In [48]:
# train(align_model, rb, qwen, sen_train_loader, sen_test_loader, optimizer, criterion, epochs=2, scaler=None)

In [None]:
# Train on the math questions only, without a gradient scaler.
# NOTE(review): the saved output below reports "epoch_1/10" although epochs=100
# is passed here, so that output appears to come from an earlier run of this cell.
train(align_model, rb, qwen, math_train_loader, math_test_loader, optimizer, criterion, epochs=100, scaler=None, clip_value=1.0)

epoch_1/10: 100%|██████████| 309/309 [01:56<00:00,  2.65it/s]
Validation :: 100%|██████████| 78/78 [00:27<00:00,  2.83it/s]


Epoch 1/10, Train Loss: 0.0004087, Val Loss: 0.0003859


epoch_2/10:   3%|▎         | 8/309 [00:03<02:06,  2.39it/s]


KeyboardInterrupt: 

In [None]:
def save_model(model, filename='model.pth'):
    """Save ``model``'s weights (its ``state_dict``) to ``filename``.

    Wrappers such as ``torch.nn.DataParallel`` expose the wrapped network as
    ``.module``; in that case the inner network's weights are saved so the
    keys carry no ``module.`` prefix. A plain module has no such attribute and
    is saved directly.

    Args:
        model: A module, or a wrapper exposing the module as ``.module``.
        filename: Destination path of the weights file.
    """
    # Unconditionally reading model.module raises AttributeError for a model
    # that was never wrapped, so fall back to the model itself.
    core = getattr(model, "module", model)
    torch.save(core.state_dict(), filename)
    print(f"Model saved to {filename}")

In [None]:
# Write the alignment model's weights to qwen_rb_pooler.pth in the working directory.
save_model(align_model, filename='qwen_rb_pooler.pth')