# LASER Sentence Embeddings to Qwen-Math Token-Embedding Constructor

In [None]:
# Install the LASER encoder package imported below (IPython/Jupyter magic, not plain Python).
%pip install laser_encoders

# Imports

In [None]:
import numpy as np
import pandas as pd

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

In [None]:
from sklearn.model_selection import train_test_split

In [None]:
from datasets import load_dataset

In [None]:
from laser_encoders import LaserEncoderPipeline

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM

In [None]:
from tqdm import tqdm

In [None]:
import re

In [None]:
import os

# Matrix Multiplication Configurations

In [None]:
# Let float32 matrix multiplications use faster reduced-precision kernels
# (e.g. TF32 on A100-class GPUs). Per the PyTorch docs, "high" trades some
# precision for speed relative to the default "highest", and falls back to
# full float32 where no such kernels exist, so it is safe to keep on any device.
torch.set_float32_matmul_precision('high')

In [None]:
# Prefer the GPU when one is visible; record how many GPUs exist.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
number_of_gpus = torch.cuda.device_count()

# Constants

In [None]:
M = 1_000_000  # one million, used to size the sentence sample

# Parameters

In [None]:
run_name: str = "laser2qwen"  # used as the file name of the exported weights

In [None]:
cache_dir = "data_cache"          # cache_dir passed to load_dataset
model_dir = "model_cache"         # cache_dir passed to the Qwen from_pretrained calls
checkpoints = "checkpoints"       # where save_checkpoint/load_checkpoint read and write
logs = "logs"                     # where train_model appends logs.txt
trained_model = "trained_model"   # where the final weights are saved

In [None]:
sentence_slice = M * 5  # number of Wikipedia sentences sampled into the training set

In [None]:
batch_size = 2 ** 10  # 1024 texts per DataLoader batch

In [None]:
max_tokens: int = 100  # fixed Qwen token length (pad/truncate) and number of decode steps

In [None]:
# Use at most 4 GPUs, and never more than are actually present.
# NOTE(review): `gpus` is only printed in the visible code; training runs on
# the single `device`. Confirm whether multi-GPU training was intended.
max_gpus = min(4, number_of_gpus)
gpus = [*range(max_gpus)]

In [None]:
print("Using {} GPUs. Gpus: {}".format(max_gpus, gpus))

In [None]:
def create_directories(*dirs):
    """Create every directory in *dirs* (including parents) if missing, then report."""
    for path in dirs:
        os.makedirs(name=path, exist_ok=True)
    print(f"Directories {dirs} ensured.")

In [None]:
# trained_model is included so the final save_model call into
# f"{trained_model}/{run_name}.pth" has an existing directory to write to.
create_directories(cache_dir, model_dir, checkpoints, logs, trained_model)

# Dataset Loading and Preprocessing

In [None]:
def extract_num_output(text):
    """Return the text following "The answer is: " up to the end of *text*.

    The pattern runs without re.MULTILINE, so it only matches when the answer
    sits at the very end of the string (optionally followed by one trailing
    newline). Returns None when there is no such match.
    """
    found = re.search(r'(?<=The answer is:\s).*$', text)
    return found.group(0) if found else None

In [None]:
# MetaMathQA supplies the math questions.
meta_math_ds = load_dataset(path="meta-math/MetaMathQA", cache_dir=cache_dir)
# Wikipedia sentences add general English text beyond math.
sen_ds = load_dataset(path="sentence-transformers/wikipedia-en-sentences", cache_dir=cache_dir)

In [None]:
sen_df = sen_ds["train"].to_pandas()
meta_math_df = meta_math_ds["train"].to_pandas()

In [None]:
# Final answer text (None when absent); not used further in the visible code.
meta_math_df["num_output"] = meta_math_df["output"].apply(func=extract_num_output)

# Data Loaders

In [None]:
# All math questions plus a reproducible random sample of Wikipedia sentences.
math_query = meta_math_df.loc[:, "query"].values
sen_query = sen_df["sentence"].sample(random_state=42, n=sentence_slice).values

In [None]:
# 70/30 split of the math questions; the held-out set is math-only.
X_train_math, X_test = train_test_split(math_query, random_state=42, test_size=0.3)

In [None]:
# Training texts = math training split followed by the sampled Wikipedia sentences.
X_train = np.concatenate((X_train_math, sen_query))

In [None]:
train_loader = DataLoader(dataset=X_train, shuffle=True, batch_size=batch_size)
test_loader = DataLoader(dataset=X_test, shuffle=False, batch_size=batch_size)

# Model loading

In [None]:
# LASER sentence encoder (lang="eng_Latn"); get_laser_embeddings wraps its
# encode_sentences method to produce the model's input vectors.
laser = LaserEncoderPipeline(lang="eng_Latn")

In [None]:
qwen_tokenizer = AutoTokenizer.from_pretrained(
    "Qwen/Qwen2.5-Math-1.5B-Instruct",
    cache_dir=model_dir,
    padding_side='right',
)
qwen_model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-Math-1.5B-Instruct",
    cache_dir=model_dir,
).to(device)

# Functions to get Embeddings

In [None]:
# Qwen's input embedding table (token id -> vector); get_qwen_embeddings applies
# it to tokenized text to produce the training targets.
qwen_embedding_layer=qwen_model.get_input_embeddings()

In [None]:
# Sentence embeddings from LASER
def get_laser_embeddings(texts):
    """Encode *texts* into one sentence vector each via the global LASER pipeline."""
    sentence_vectors = laser.encode_sentences(texts)
    return sentence_vectors

In [None]:
def get_qwen_embeddings(texts):
    """Tokenize *texts* for Qwen and look up their input embeddings.

    Each text is wrapped as ``<|im_start|>{text}<|im_end|>``. It is then
    tokenized to exactly ``max_tokens`` tokens (padded, or truncated when
    longer), and the ids go through Qwen's embedding table without gradient
    tracking.

    Returns:
        ``(embeddings, attention_mask)`` on ``device``.
    """
    wrapped = ["<|im_start|>{text}<|im_end|>".format(text=item) for item in texts]
    encoded = qwen_tokenizer(
        wrapped,
        padding='max_length',
        truncation=True,
        max_length=max_tokens,
        return_tensors='pt',
    ).to(device)
    with torch.no_grad():
        token_vectors = qwen_embedding_layer(encoded.input_ids)
    return token_vectors, encoded.attention_mask
    

# Testing the Embedding Functions

In [None]:
test_texts = [f"What is {n}+{n}?" for n in (2, 3, 4)]

In [None]:
laser_test_embeddings = get_laser_embeddings(texts=test_texts)

In [None]:
np.shape(laser_test_embeddings)

In [None]:
qwen_test_embeddings = get_qwen_embeddings(texts=test_texts)

In [None]:
# get_qwen_embeddings returns an (embeddings, attention_mask) tuple, which has
# no .shape attribute; inspect the embeddings tensor instead.
qwen_test_embeddings[0].shape

In [None]:
# LASER output is (batch, sentence_dim), so axis 1 is the feature size.
laser_embeddings_shape = laser_test_embeddings.shape[1]
# Qwen embeddings are (batch, max_tokens, hidden): the per-token feature size is
# the LAST axis. Axis 1 would be the padded sequence length (max_tokens).
qwen_embeddings_shape = qwen_test_embeddings[0].shape[-1]

In [None]:
(laser_embeddings_shape, qwen_embeddings_shape)

# Custom Loss Functions

In [None]:
def get_losses(criteria, outputs, targets, target_attention_mask, weight, eps=1e-6):
    """Split a reconstruction loss into real-token and padding parts and re-weight them.

    Both parts are computed with ``criteria`` on masked copies of the tensors.
    With a mean-reduced criterion such as ``nn.MSELoss()`` each part is averaged
    over ALL positions. Each part is therefore rescaled by
    ``total_positions / positions_in_that_part``, so it reflects the error on
    its own positions only. The parts are then mixed as
    ``weight * real + (1 - weight) * pad``.

    Args:
        criteria: loss callable applied as ``criteria(prediction, target)``.
        outputs: predicted embeddings, ``(batch, seq_len, dim)``.
        targets: target embeddings, same shape as ``outputs``.
        target_attention_mask: ``(batch, seq_len)`` mask, 1/True for real
            tokens and 0/False for padding (int, float or bool).
        weight: share of the combined loss given to real tokens.
        eps: keeps the normalizers finite when a part has no positions.

    Returns:
        ``(weighted_loss, attention_loss, pad_loss)``, where the last two are the
        un-normalized criterion values for real tokens and padding.
    """
    # Cast to the output dtype so bool/int masks work and `1 - mask` is valid.
    token_mask = target_attention_mask.to(outputs.dtype)
    pad_mask = 1 - token_mask
    total_positions = token_mask.numel()

    token_mask_3d = token_mask.unsqueeze(-1)
    pad_mask_3d = pad_mask.unsqueeze(-1)

    attention_loss = criteria(outputs * token_mask_3d, targets * token_mask_3d)
    pad_loss = criteria(outputs * pad_mask_3d, targets * pad_mask_3d)

    # Each normalizer divides by the number of positions in ITS OWN part
    # (padding positions for pad_norm, not real-token positions).
    attention_norm = total_positions / (token_mask.sum() + eps)
    pad_norm = total_positions / (pad_mask.sum() + eps)

    weighted_loss = attention_norm * attention_loss * weight + pad_norm * pad_loss * (1 - weight)

    return weighted_loss, attention_loss, pad_loss

# Functions to save and load Checkpoints

In [None]:
def save_checkpoint(epoch, model, optimizer, loss, path):
    """Write epoch, model/optimizer state dicts and loss to ``{checkpoints}/{path}``."""
    state = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }
    torch.save(state, f"{checkpoints}/{path}")

In [None]:
def load_checkpoint(model, optimizer, path):
    """Restore model and optimizer from ``{checkpoints}/{path}``.

    Returns:
        ``(epoch, model, optimizer, loss)`` as stored by ``save_checkpoint``.
    """
    state = torch.load(f"{checkpoints}/{path}")
    model.load_state_dict(state['model_state_dict'])
    optimizer.load_state_dict(state['optimizer_state_dict'])
    return state['epoch'], model, optimizer, state['loss']

# Functions to Freeze and Unfreeze Models

In [None]:
def freeeze_model(model, freeze = True):
    """Set ``requires_grad`` on every parameter of *model* to ``not freeze``."""
    trainable = not freeze
    for parameter in model.parameters():
        parameter.requires_grad = trainable

In [None]:
def freeze_qwen(freeze = True):
    """Freeze (or with ``freeze=False`` unfreeze) all parameters of ``qwen_model``."""
    freeeze_model(model=qwen_model, freeze=freeze)

In [None]:
def freeze_laser(freeze = True):
    """Freeze (or with ``freeze=False`` unfreeze) all parameters of ``laser.model``."""
    freeeze_model(model=laser.model, freeze=freeze)

# Alignment Models

In [None]:
# LSTM Decoder Model
class lstmDecoder(nn.Module):
    """One decoding step: a (possibly stacked) LSTM followed by a linear read-out.

    The LSTM consumes ``output_size`` features per step and the linear layer
    maps its last-step hidden output back to ``output_size`` features, so the
    prediction can be fed in as the next input.
    """

    def __init__(self, output_size, hidden_size, num_layers, dropout):
        super().__init__()
        self.output_size = output_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # Registration order (lstm, then fc) is kept so parameter order and
        # random initialisation are unchanged.
        self.lstm = nn.LSTM(output_size, hidden_size, num_layers, batch_first=True, dropout=dropout)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, prev_state):
        """Run one step.

        Args:
            x: ``(batch, 1, output_size)`` input step.
            prev_state: ``(h, c)`` each ``(num_layers, batch, hidden_size)``.

        Returns:
            ``(prediction, (h, c))`` with prediction ``(batch, output_size)``.
        """
        lstm_out, new_state = self.lstm(x, prev_state)
        last_step = lstm_out[:, -1, :]
        return self.fc(last_step), new_state

In [None]:
class lstmARDecoder(nn.Module):
    """Autoregressive LSTM decoder that unrolls one vector into a sequence.

    ``source`` initialises the LSTM *cell* state; the hidden state starts at
    zero and the first decoder input is an all-zero vector. Each step's
    prediction is fed back as the next input. When ``target`` is given, the
    ground-truth step ``target[:, t]`` is fed back instead with probability
    ``teacher_forcing_ratio`` per step.

    Args:
        output_size: feature size of each decoded step (also the LSTM input size).
        hidden_size: LSTM hidden/cell size; must equal the feature size of ``source``.
        max_tokens: number of decoding steps.
        num_layers: number of stacked LSTM layers.
        dropout: dropout between stacked LSTM layers.
    """

    def __init__(self, output_size, hidden_size, max_tokens, num_layers=1, dropout=0):
        super(lstmARDecoder, self).__init__()
        self.output_size = output_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.max_tokens = max_tokens
        self.decoder = lstmDecoder(output_size, hidden_size, num_layers, dropout)

    def forward(self, source, target=None, teacher_forcing_ratio = 0.5):
        """Decode ``max_tokens`` steps.

        Args:
            source: ``(batch, hidden_size)`` vector used as the initial cell
                state of every layer, or a ready-made
                ``(num_layers, batch, hidden_size)`` cell state.
            target: optional ``(batch, >= max_tokens, output_size)`` tensor of
                ground-truth steps for teacher forcing.
            teacher_forcing_ratio: per-step probability of feeding back
                ``target[:, t]`` instead of the model's own prediction.

        Returns:
            Tensor of shape ``(batch, max_tokens, output_size)``.

        Raises:
            ValueError: if ``source`` cannot be used as the LSTM cell state.
        """
        if source.dim() == 2:
            # nn.LSTM needs a 3-D (num_layers, batch, hidden) cell state for
            # batched input; repeat the sentence vector for every layer.
            cell = source.unsqueeze(0).expand(self.num_layers, -1, -1).contiguous()
        elif source.dim() == 3:
            cell = source
        else:
            raise ValueError(f"source must be 2-D or 3-D, got shape {tuple(source.shape)}")
        if cell.size(0) != self.num_layers or cell.size(2) != self.hidden_size:
            raise ValueError(
                f"source of shape {tuple(source.shape)} cannot be used as an LSTM cell "
                f"state with num_layers={self.num_layers}, hidden_size={self.hidden_size}"
            )

        batch_size = cell.size(1)
        device = source.device

        decoder_input = torch.zeros(batch_size, 1, self.output_size, device=device)
        hidden = torch.zeros(self.num_layers, batch_size, self.hidden_size, device=device)

        outputs = []
        for t in range(self.max_tokens):
            decoder_output, (hidden, cell) = self.decoder(decoder_input, (hidden, cell))
            outputs.append(decoder_output)

            if target is not None:
                # Teacher forcing: coin flip per step between ground truth and prediction.
                teacher_force = torch.rand(1).item() < teacher_forcing_ratio
                decoder_input = target[:, t].unsqueeze(1) if teacher_force else decoder_output.unsqueeze(1)
            else:
                # Inference mode: always feed back the prediction.
                decoder_input = decoder_output.unsqueeze(1)

        # (batch_size, max_tokens, output_size)
        return torch.stack(outputs, dim=1)
        

In [None]:
# CNN Reconstructor Model
class AdvancedSeqReconstructor(nn.Module):
    """Widen each step of a feature sequence from ``compressed_dim`` to ``target_dim`` channels.

    Three stride-1 ``ConvTranspose1d`` layers expand the channels
    compressed_dim -> target_dim//4 -> target_dim//2 -> target_dim. The
    sequence length is unchanged because the kernel size is odd and padding
    is symmetric. The final ``Tanh`` bounds every output value to [-1, 1].
    NOTE(review): confirm the regressed Qwen embedding values stay within [-1, 1].

    Args:
        compressed_dim: channels of the input features.
        target_dim: channels of the output features.
        kernel_size: temporal kernel size; must be a positive odd integer.

    Raises:
        ValueError: if ``kernel_size`` is not a positive odd integer.
    """

    def __init__(self, compressed_dim, target_dim, kernel_size):
        super(AdvancedSeqReconstructor, self).__init__()
        if kernel_size < 1 or kernel_size % 2 == 0:
            # Stride-1 ConvTranspose1d gives L_out = L_in + kernel_size - 1 - 2*padding.
            # With padding = (kernel_size - 1) // 2 that equals L_in only for odd
            # kernels; even kernels would lengthen the sequence by one per layer
            # and break alignment with the max_tokens-long targets.
            raise ValueError(f"kernel_size must be a positive odd integer, got {kernel_size}")
        self.kernel_size = kernel_size
        self.padding_size = (kernel_size - 1) // 2

        self.reconstructor = nn.Sequential(
            # Channel expansion 1: compressed_dim -> target_dim//4
            nn.ConvTranspose1d(compressed_dim, target_dim//4, kernel_size=self.kernel_size, padding=self.padding_size),
            nn.BatchNorm1d(target_dim//4),
            nn.LeakyReLU(0.1),
            nn.Dropout(0.1),

            # Channel expansion 2: target_dim//4 -> target_dim//2 (no batch norm here)
            nn.ConvTranspose1d(target_dim//4, target_dim//2, kernel_size=self.kernel_size, padding=self.padding_size),
            nn.LeakyReLU(0.1),
            nn.Dropout(0.1),

            # Channel expansion 3: target_dim//2 -> target_dim, squashed to [-1, 1]
            nn.ConvTranspose1d(target_dim//2, target_dim, kernel_size=self.kernel_size, padding=self.padding_size),
            nn.Tanh(),
        )

    def forward(self, x):
        """Map ``(batch, seq_len, compressed_dim)`` to ``(batch, seq_len, target_dim)``."""
        # Conv layers expect channels first: (batch, compressed_dim, seq_len).
        x = x.transpose(1, 2)
        x = self.reconstructor(x)
        # Back to (batch, seq_len, target_dim).
        return x.transpose(1, 2)

In [None]:
# Wrapper model combining the LSTM decoder and the CNN reconstructor
class CNNwrapper(nn.Module):
    """Sentence embedding -> sequence of token embeddings (LSTM decoder + CNN).

    ``lstmARDecoder`` unrolls the sentence vector into ``max_tokens`` steps of
    ``reduced_dim`` features. ``AdvancedSeqReconstructor`` then widens each
    step from ``reduced_dim`` to ``token_dim`` channels.

    Args:
        sentence_dim: size of the input sentence embedding.
        reduced_dim: feature size of the intermediate LSTM sequence.
        token_dim: feature size of each output token embedding.
        max_tokens: number of output steps.
        num_lstm_layer: number of stacked LSTM layers.
        lstm_dropout: dropout between stacked LSTM layers.
        cnn_kernel_size: temporal kernel size of the CNN reconstructor.
    """

    def __init__(self, sentence_dim, reduced_dim, token_dim, max_tokens, num_lstm_layer = 1, lstm_dropout = 0, cnn_kernel_size = 1):
        super(CNNwrapper, self).__init__()
        # lstmARDecoder(output_size, hidden_size, ...) emits output_size features
        # per step. The CNN consumes those as its reduced_dim input channels, so
        # output_size must be reduced_dim. The decoder also uses `source` as its
        # cell state, so hidden_size must be sentence_dim. The previous argument
        # order (sentence_dim, reduced_dim) mismatched both.
        self.lstm_model = lstmARDecoder(reduced_dim, sentence_dim, max_tokens, num_lstm_layer, lstm_dropout)
        self.cnn_model = AdvancedSeqReconstructor(reduced_dim, token_dim, cnn_kernel_size)

    def forward(self, source, target=None, teacher_forcing_ratio = 0.5, attention_mask=None):
        """Map ``source`` sentence embeddings to ``(batch, max_tokens, token_dim)``.

        Args:
            source: sentence embeddings, ``(batch, sentence_dim)``.
            target: optional teacher-forcing targets for the LSTM stage; see below.
            teacher_forcing_ratio: forwarded to the LSTM decoder.
            attention_mask: accepted so callers may pass it. It is unused here;
                the visible training loop applies the mask in ``get_losses``.
        """
        # Teacher forcing feeds target steps back into the LSTM, so they must
        # have the LSTM's reduced_dim features. Token-level targets (token_dim
        # features) cannot be fed back; in that case decode free-running.
        if target is not None and target.size(-1) != self.lstm_model.output_size:
            target = None
        lstm_output = self.lstm_model(source, target, teacher_forcing_ratio)
        return self.cnn_model(lstm_output)
        

# Training Loop

In [None]:
def _laser_batch_tensor(texts):
    """Encode a batch of texts with LASER as a float32 tensor on ``device``."""
    # torch.as_tensor accepts the array returned by get_laser_embeddings (or an
    # existing tensor) and places it on the same device as the model.
    return torch.as_tensor(get_laser_embeddings(texts), dtype=torch.float32, device=device)


def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=20, w=0.9):
    """Train ``model`` to map LASER sentence embeddings to Qwen token embeddings.

    Each epoch has two passes:
    - A training pass with teacher forcing. The ratio starts near 0.80, drops
      by 0.02 per epoch and is floored at 0.
    - A validation pass that decodes without a target.

    Whenever the validation loss improves, a checkpoint is written with
    ``save_checkpoint(..., "best_model.pth")``. A per-epoch loss line is
    printed and appended to ``{logs}/logs.txt``.

    Args:
        model: module called as ``model(source, target, teacher_forcing_ratio=...)``.
        train_loader: loader yielding batches of raw training texts.
        val_loader: loader yielding batches of raw validation texts.
        criterion: loss passed to ``get_losses``.
        optimizer: optimizer over ``model``'s parameters.
        num_epochs: number of epochs to run.
        w: weight of the real-token loss relative to the padding loss.
    """
    print(f"Using device: {device}")

    best_val_loss = np.inf

    # Move the model to the selected device (GPU or CPU)
    model.to(device)

    teacher_forcing_ratio = 0.82

    for epoch in range(num_epochs):
        # Decays by 0.02 per epoch (about 0.80 in the first epoch), never below 0.
        teacher_forcing_ratio = max(0, teacher_forcing_ratio - 0.02)
        model_saved_at_epoch = False

        # Training phase
        model.train()
        train_loss = train_attention_loss = train_pad_loss = 0.0
        train_count = 0
        for inputs in tqdm(train_loader, desc=f'epoch_{epoch+1}/{num_epochs}'):
            laser_embeddings = _laser_batch_tensor(inputs)
            embeddings, attention_mask = get_qwen_embeddings(inputs)

            optimizer.zero_grad()

            # The padding mask is only needed by the loss; the decayed ratio
            # above is what controls teacher forcing inside the model.
            outputs = model(laser_embeddings, embeddings, teacher_forcing_ratio=teacher_forcing_ratio)
            weighted_loss, attention_loss, pad_loss = get_losses(criterion, outputs, embeddings, attention_mask, w)

            if not torch.isfinite(weighted_loss):
                print("Numerical instability detected in training. Skipping this batch.")
                continue

            weighted_loss.backward()
            optimizer.step()

            batch_n = embeddings.size(0)
            train_count += batch_n
            train_loss += weighted_loss.item() * batch_n
            train_attention_loss += attention_loss.item() * batch_n
            train_pad_loss += pad_loss.item() * batch_n

        # Validation phase: feed the model its real input (LASER embeddings of
        # the validation texts) and no target, so ground truth cannot leak in
        # through teacher forcing.
        model.eval()
        val_loss = val_attention_loss = val_pad_loss = 0.0
        val_count = 0
        with torch.no_grad():
            for inputs in tqdm(val_loader):
                laser_embeddings = _laser_batch_tensor(inputs)
                embeddings, attention_mask = get_qwen_embeddings(inputs)

                outputs = model(laser_embeddings)
                weighted_loss, attention_loss, pad_loss = get_losses(criterion, outputs, embeddings, attention_mask, w)

                batch_n = embeddings.size(0)
                val_count += batch_n
                val_loss += weighted_loss.item() * batch_n
                val_attention_loss += attention_loss.item() * batch_n
                val_pad_loss += pad_loss.item() * batch_n

        # Average over the samples actually used (skipped batches excluded);
        # max(..., 1) avoids dividing by zero if nothing was processed.
        train_count = max(train_count, 1)
        val_count = max(val_count, 1)
        train_loss /= train_count
        train_attention_loss /= train_count
        train_pad_loss /= train_count
        val_loss /= val_count
        val_attention_loss /= val_count
        val_pad_loss /= val_count

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            save_checkpoint(epoch, model, optimizer, best_val_loss, "best_model.pth")
            print(f"Best model saved with loss: {best_val_loss:.7f} at epoch {epoch+1}/{num_epochs}")
            model_saved_at_epoch = True

        log_line = (f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.7f}, Val Loss: {val_loss:.7f}, "
                    f"Train Attention Loss: {train_attention_loss:.7f}, Val Attention Loss: {val_attention_loss:.7f}, "
                    f"Train Pad Loss: {train_pad_loss:.7f}, Val Pad Loss: {val_pad_loss:.7f}")
        print(log_line)
        with open(f"{logs}/logs.txt", "a") as log_file:
            log_file.write(log_line + f" model_saved {model_saved_at_epoch}" + "\n")


# Model Training

In [None]:
# Function to save the Model
def save_model(model, path):
    """Save ``model.state_dict()`` to *path*, creating missing parent directories.

    Without the directory creation, ``torch.save`` fails with
    FileNotFoundError when the target folder does not exist yet.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    torch.save(model.state_dict(), path)

In [None]:
# Count model parameters
def count_parameters(model: nn.Module):
    """Return ``(trainable, non_trainable)`` parameter-element counts of *model*.

    Trainable means ``requires_grad`` is True; everything else counts as non-trainable.
    """
    trainable = 0
    frozen = 0
    for tensor in model.parameters():
        if tensor.requires_grad:
            trainable += tensor.numel()
        else:
            frozen += tensor.numel()
    return trainable, frozen

In [None]:
# LASER sentence dim -> 100-dim intermediate sequence -> Qwen token-embedding dim.
a_model = CNNwrapper(
    sentence_dim=laser_embeddings_shape,
    reduced_dim=100,
    token_dim=qwen_embeddings_shape,
    max_tokens=max_tokens,
)

In [None]:
# Loss function: element-wise mean squared error (split into real-token and
# padding parts by get_losses during training).
criterion = nn.MSELoss()

# Optimizer: AdamW (Adam with decoupled weight decay).
optimizer = torch.optim.AdamW(a_model.parameters(), lr=1e-3, weight_decay=1e-2)

In [None]:
count_parameters(model=a_model)

In [None]:
train_model(
    model=a_model,
    train_loader=train_loader,
    val_loader=test_loader,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=30,
)

In [None]:
save_model(a_model, path=f"{trained_model}/{run_name}.pth")