In [1]:
import numpy as np
import pandas as pd

In [2]:
import torch
import torch.nn as nn

In [3]:
from transformers import AutoTokenizer, AutoModelForCausalLM

In [4]:
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split

In [5]:
from tqdm import tqdm

In [6]:
data_dir = "data_cache"
model_dir = "model_cache"

In [7]:
math_df = pd.read_json("hf://datasets/meta-math/MetaMathQA/MetaMathQA-395K.json")

In [8]:
batch_size = 1024

In [9]:
math_train_query = math_df['query']

In [10]:
math_array = math_train_query.to_numpy()

In [11]:
math_X_train, math_X_test = train_test_split(math_array, test_size=0.2, random_state=42)

In [12]:
math_train_loader = DataLoader(math_X_train, batch_size=batch_size, shuffle=True)
math_test_loader = DataLoader(math_X_test, batch_size=batch_size, shuffle=False)

In [13]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [14]:
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-Math-1.5B-Instruct", cache_dir=model_dir)
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-Math-1.5B-Instruct", cache_dir=model_dir).to(device)

In [15]:
qwen_embedding_layer = model.get_input_embeddings()

In [16]:
# gpu_ids = [0,1]

In [17]:
# qwen_embedding_layer = torch.nn.DataParallel(qwen_embedding_layer, device_ids=gpu_ids)

In [18]:
def get_qwen_embeddings(texts):
    template = "<|im_start|>{text}<|im_end|>"
    texts = [template.format(text=text) for text in texts]
    tokens = tokenizer(
            texts,
            return_tensors='pt',
            truncation=True,
            padding=True,
            max_length=200
        ).to(device)
    with torch.no_grad():
        embeddings = qwen_embedding_layer(tokens.input_ids)
    return embeddings, tokens.attention_mask

In [19]:
class AdvancedSeqDimReducer(nn.Module):
    def __init__(self, input_dim, target_dim, kernel_size=1):
        super(AdvancedSeqDimReducer, self).__init__()
        self.kernel_size = kernel_size
        self.padding_size = (kernel_size - 1) // 2
        
        self.reducer = nn.Sequential(
            nn.Conv1d(input_dim, input_dim//2, kernel_size=self.kernel_size, padding=self.padding_size),
            nn.BatchNorm1d(input_dim//2),
            nn.LeakyReLU(0.1),
            # nn.Dropout(0.1),
            
            nn.Conv1d(input_dim//2, input_dim//4, kernel_size=self.kernel_size, padding=self.padding_size),
            nn.BatchNorm1d(input_dim//4),
            nn.LeakyReLU(0.1),
            nn.Dropout(0.1),
            
            nn.Conv1d(input_dim//4, target_dim, kernel_size=self.kernel_size, padding=self.padding_size),
            nn.LeakyReLU(0.1),
        )
    
    def forward(self, x):
        x = x.transpose(1, 2)
        x = self.reducer(x)
        return x.transpose(1, 2)

In [20]:
class AdvancedSeqReconstructor(nn.Module):
    def __init__(self, compressed_dim, target_dim, kernel_size):
        super(AdvancedSeqReconstructor, self).__init__()
        self.kernel_size = kernel_size
        self.padding_size = (kernel_size - 1) // 2
        
        self.reconstructor = nn.Sequential(
            # First upsampling: compressed_dim → target_dim//4
            nn.ConvTranspose1d(compressed_dim, target_dim//4, kernel_size=self.kernel_size, padding=self.padding_size),
            nn.BatchNorm1d(target_dim//4),
            nn.LeakyReLU(0.1),
            nn.Dropout(0.1),
            
            # Second upsampling: target_dim//4 → target_dim//2
            nn.ConvTranspose1d(target_dim//4, target_dim//2, kernel_size=self.kernel_size, padding=self.padding_size),
            nn.BatchNorm1d(target_dim//2),
            nn.LeakyReLU(0.1),
            nn.Dropout(0.1),
            
            # Final upsampling: target_dim//2 → target_dim
            nn.ConvTranspose1d(target_dim//2, target_dim, kernel_size=self.kernel_size, padding=self.padding_size),
            nn.Tanh(),
        )

    def forward(self, x):
        # Transpose for ConvTranspose1d operation
        x = x.transpose(1, 2)  # (batch_size, compressed_dim, sequence_length)
        
        # Apply reconstruction
        x = self.reconstructor(x)
        
        # Transpose back to original format
        return x.transpose(1, 2)  # (batch_size, sequence_length, target_dim)

In [21]:
class CNNAE(nn.Module):
    def __init__(self, input_dim,compressed_dim, target_dim, kernel_size):
        super(CNNAE, self).__init__()
        self.kernel_size = kernel_size
        self.padding = (kernel_size - 1) / 2
        self.encoder = AdvancedSeqDimReducer(input_dim,compressed_dim, kernel_size)
        self.decoder = AdvancedSeqReconstructor(compressed_dim,target_dim, kernel_size)

    def forward(self,x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded

In [22]:
amodel=CNNAE(1536,200,1536,1)

In [23]:
amodel.to(device)
# amodel = torch.nn.DataParallel(amodel, device_ids=gpu_ids)

CNNAE(
  (encoder): AdvancedSeqDimReducer(
    (reducer): Sequential(
      (0): Conv1d(1536, 768, kernel_size=(1,), stride=(1,))
      (1): BatchNorm1d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.1)
      (3): Conv1d(768, 384, kernel_size=(1,), stride=(1,))
      (4): BatchNorm1d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): LeakyReLU(negative_slope=0.1)
      (6): Dropout(p=0.1, inplace=False)
      (7): Conv1d(384, 200, kernel_size=(1,), stride=(1,))
      (8): LeakyReLU(negative_slope=0.1)
    )
  )
  (decoder): AdvancedSeqReconstructor(
    (reconstructor): Sequential(
      (0): ConvTranspose1d(200, 384, kernel_size=(1,), stride=(1,))
      (1): BatchNorm1d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.1)
      (3): Dropout(p=0.1, inplace=False)
      (4): ConvTranspose1d(384, 768, kernel_size=(1,), stride=(1,))
      (5): B

In [24]:
def train(model,train_loader, test_loader, optimizer, criterion ,epochs=10):
    
    best_val_loss = np.inf

    for epoch in range(epochs):
        model.train()
        train_loss = 0
        val_loss = 0
        model_saved = False
        

        for batch in tqdm(train_loader, desc = f'epoch_{epoch+1}/{epochs}'):
            
            with torch.no_grad():
                encoder_embedding, _ = get_qwen_embeddings(batch)
                llm_embedding = encoder_embedding

            
            # Use AMP for mixed precision if scaler is provided
            # with torch.amp.autocast("cuda", enabled=(scaler is not None)):
            output = model(encoder_embedding)
            loss = criterion(output, llm_embedding)

            if torch.isnan(loss) or torch.isinf(loss):
                print(f"Numerical instability detected in training. Skipping this batch.")
                continue

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
                
            train_loss += loss.item() * output.size(0)

        train_loss /= len(train_loader.dataset)
        

        model.eval()
        with torch.no_grad():
            for batch in tqdm(test_loader, desc="Validation :"):
                encoder_embedding,_ = get_qwen_embeddings(batch)
                llm_embedding = encoder_embedding

                output = model(encoder_embedding)
                loss = criterion(output, llm_embedding)

                if torch.isnan(loss) or torch.isinf(loss):
                    print(f"Numerical instability detected in validation. Skipping this batch.")
                    continue

                val_loss += loss.item() * output.size(0)
                
            val_loss /= len(test_loader.dataset)
            if val_loss < best_val_loss:
                best_val_loss = val_loss

        print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.7f}, Val Loss: {val_loss:.7f}")

In [25]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = torch.nn.MSELoss()

In [26]:
train(amodel,math_train_loader,math_test_loader,optimizer,criterion)

epoch_1/10: 100%|██████████| 309/309 [01:16<00:00,  4.05it/s]
Validation :: 100%|██████████| 78/78 [00:13<00:00,  5.91it/s]


Epoch 1/10, Train Loss: 0.0647942, Val Loss: 0.0488328


epoch_2/10: 100%|██████████| 309/309 [01:16<00:00,  4.04it/s]
Validation :: 100%|██████████| 78/78 [00:13<00:00,  5.91it/s]


Epoch 2/10, Train Loss: 0.0647937, Val Loss: 0.0486096


epoch_3/10:  15%|█▍        | 46/309 [00:11<01:05,  4.03it/s]


KeyboardInterrupt: 