In [1]:
import numpy as np
import pandas as pd

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [3]:
from torch.utils.data import DataLoader, TensorDataset

In [4]:
from datasets import load_dataset

In [5]:
from abc import ABC, abstractmethod
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel
from transformers.configuration_utils import PretrainedConfig

In [6]:
from tqdm import tqdm

In [7]:
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast
)

In [8]:
from einops import rearrange

In [9]:
import re

In [10]:
from sklearn.model_selection import train_test_split

In [11]:
import os

# Dataset Loading and Preprocessing

In [12]:
def extract_num_output(text):
    """Return the text that follows ``The answer is:`` at the end of a response.

    The pattern requires exactly one whitespace character after the colon and
    then captures everything up to the end of the string. Because ``.`` does
    not match newlines and ``$`` anchors at the end, the marker has to sit on
    the final line of ``text``.

    Args:
        text: Response string to search. Non-string values (``None``, or the
            ``NaN`` pandas uses for missing cells) are treated as "no answer"
            instead of raising ``TypeError`` inside ``re.search``.

    Returns:
        The captured answer string, or ``None`` when there is no match.
    """
    # Series.apply hands missing cells through as NaN/None; re.search would
    # raise on those, aborting the whole column computation.
    if not isinstance(text, str):
        return None
    match = re.search(r'(?<=The answer is:\s).*$', text)
    return match.group(0) if match else None

In [13]:
# Directory handed to load_dataset as its cache_dir.
cache_dir = "data_cache"
# Directory later passed to NITConfig as model_cache_dir.
model_dir = "model_cache"
# Load the MetaMathQA dataset; the 'train' split is read from `ds` below.
ds = load_dataset("meta-math/MetaMathQA", cache_dir=cache_dir)

In [14]:
# Dataset.to_pandas() converts the underlying table in one step, instead of
# the row-by-row Python iteration that pd.DataFrame(ds['train']) triggers.
df = ds['train'].to_pandas()

In [15]:
# Store the extracted final answer of every response in its own column.
df['Numerical_output'] = df['response'].map(extract_num_output)

# Main Classes

In [16]:
class TokenizerOutputs:
    """Plain container pairing token ids with their attention mask."""

    def __init__(self, input_ids, attention_mask):
        """Store both values unchanged on the instance."""
        self.input_ids, self.attention_mask = input_ids, attention_mask

In [17]:
class EmbeddingsOutputs:
    """Plain container pairing input embeddings with their attention mask."""

    def __init__(self, input_embeds, attention_mask):
        """Store both values unchanged on the instance."""
        self.input_embeds, self.attention_mask = input_embeds, attention_mask

In [18]:
class AlignmentModel(ABC, nn.Module):
    """Abstract base for modules that map one embedding block to another.

    Attributes:
        input_shape: Per-sample shape the model accepts (no batch axis).
        output_shape: Per-sample shape the model produces (no batch axis).
    """
    input_shape : tuple
    output_shape : tuple

    def __init__(self,input_shape, output_shape):
        super().__init__()
        self.input_shape = input_shape
        self.output_shape = output_shape

    @abstractmethod
    def forward(self, x):
        """Transform a batch ``x``; subclasses must implement this."""
        pass

    def test(self):
        """Smoke-test the declared shapes with a single all-zero sample.

        A batch of one zero tensor shaped ``(1, *input_shape)`` is pushed
        through the module and the per-sample output shape is compared with
        ``output_shape``.

        Raises:
            AssertionError: If the produced shape differs from ``output_shape``.
        """
        # Build the probe as a torch tensor (not a NumPy array) with a leading
        # batch axis, on the same device/dtype as the module's parameters, so
        # torch layers can actually consume it.
        reference = next(self.parameters(), None)
        device = reference.device if reference is not None else None
        dtype = reference.dtype if reference is not None else None
        sample_input = torch.zeros((1, *self.input_shape), device=device, dtype=dtype)

        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                output = self(sample_input)
        finally:
            # Restore whatever train/eval mode the caller had set.
            self.train(was_training)

        produced = tuple(output.shape[1:])
        expected = tuple(self.output_shape)
        assert produced == expected, (
            f"expected per-sample output shape {expected}, got {produced}"
        )

In [19]:
class NITConfig(PretrainedConfig):
    """Configuration consumed by ``NITModel``.

    Args:
        llm_name: Identifier of the causal LM to load; stored as ``llm``.
        llm_dim: Embedding width of that LM.
        freez_encoder: Stored as given on the config.
        freez_llm: When true, ``NITModel`` freezes the LM parameters.
        alignment_class: Class that ``NITModel`` instantiates with the input
            and output shapes to build its alignment model.
        max_tokens: Padded/truncated sequence length used for tokenization.
        model_cache_dir: Cache directory for the LM; stored as ``model_cache``.
        **kwargs: Forwarded to ``PretrainedConfig``.
    """

    modelName = "NITModel"

    def __init__(
        self,
        llm_name="Qwen/Qwen-1.5B",
        llm_dim=1536,
        freez_encoder=True,
        freez_llm=True,
        alignment_class=None,
        max_tokens=200,
        model_cache_dir="model_cache",
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Which language model to load.
        self.llm = llm_name

        # Freezing switches.
        self.freez_encoder = freez_encoder
        self.freez_llm = freez_llm

        # Alignment model factory and sequence/cache settings.
        self.alignment_class = alignment_class
        self.max_tokens = max_tokens
        self.model_cache = model_cache_dir
        self.llm_dim = llm_dim

In [20]:
class NITModel(PreTrainedModel):
    """Causal LM whose input embeddings can be routed through an alignment model.

    Two paths share one tokenizer and one LM:

    * the *original* path feeds the LM's own input embeddings to the LM;
    * the *altered* path first passes those embeddings through
      ``alingment_model``.

    ``forward`` trains the altered path to reproduce the token predictions of
    the original path.
    """
    tokenizer : AutoTokenizer
    llm : AutoModelForCausalLM
    # (sic) The misspelled attribute name is kept: it is part of the state_dict keys.
    alingment_model : AlignmentModel
    embeddings: nn.Embedding

    def __init__(self, config : NITConfig):
        """Load tokenizer and LM named by ``config`` and build the alignment model.

        Raises:
            TypeError: If ``config.alignment_class`` is ``None``.
        """
        super().__init__(config)
        # Fail before the expensive LM download/load rather than after it.
        if config.alignment_class is None:
            raise TypeError(
                "NITConfig.alignment_class must be an AlignmentModel subclass, got None"
            )
        self.tokenizer = AutoTokenizer.from_pretrained(config.llm,cache_dir=config.model_cache, padding_side='left')
        self.llm = AutoModelForCausalLM.from_pretrained(config.llm,cache_dir=config.model_cache)
        self.input_shape = (config.max_tokens,config.llm_dim)
        self.output_shape = (config.max_tokens,config.llm_dim)
        self.llm_dim = config.llm_dim
        self.maxTokens = config.max_tokens
        self.alingment_model = config.alignment_class(self.input_shape,self.output_shape)
        # Informational only; modules are not moved here.
        self.device_me = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.embeddings = self.llm.get_input_embeddings()

        if config.freez_llm:
            self.freeze_lm()

    def freeze_lm(self):
        """Disable gradients for every LM parameter."""
        for param in self.llm.parameters():
            param.requires_grad = False

    def unfreeze_lm(self):
        """Enable gradients for every LM parameter."""
        for param in self.llm.parameters():
            param.requires_grad = True

    def get_inputIds(self, texts):
        """Tokenize ``texts`` to fixed-length tensors on the embedding layer's device.

        Sequences are padded/truncated to ``self.maxTokens``.
        """
        encoding = self.tokenizer(
            texts,
            return_tensors='pt',
            padding='max_length',
            truncation=True,
            max_length=self.maxTokens
        )
        # The tokenizer returns CPU tensors; move them next to the embedding
        # weights so the lookup works wherever the model has been placed.
        device = self.embeddings.weight.device
        return TokenizerOutputs(
            encoding["input_ids"].to(device),
            encoding["attention_mask"].to(device),
        )

    def get_embeddings(self, inputs : TokenizerOutputs):
        """Look up the LM's input embeddings for already-tokenized inputs."""
        return EmbeddingsOutputs(self.embeddings(inputs.input_ids), inputs.attention_mask)

    def get_embeddings_altered(self, inputs : TokenizerOutputs):
        """Look up embeddings and pass them through the alignment model."""
        return EmbeddingsOutputs(self.alingment_model(self.get_embeddings(inputs).input_embeds), inputs.attention_mask)

    def original_pipeline(self, texts):
        """Tokenize ``texts`` and return the unmodified embeddings."""
        return self.get_embeddings(self.get_inputIds(texts))

    def altered_pipeline(self, texts):
        """Tokenize ``texts`` and return the alignment-model embeddings."""
        return self.get_embeddings_altered(self.get_inputIds(texts))

    def get_outputs(self, embeddings : EmbeddingsOutputs):
        """Run the LM on precomputed embeddings and return its output object."""
        return self.llm(
            attention_mask=embeddings.attention_mask,
            inputs_embeds=embeddings.input_embeds,
            return_dict=True
        )

    def pipeline_outputs(self, texts, altered = False, grad = False):
        """Run the full text -> LM output pipeline.

        Args:
            texts: Input text(s) accepted by the tokenizer.
            altered: Use the alignment-model path when true.
            grad: Currently unused; kept so existing callers keep working.
        """
        if altered:
            return self.get_outputs(self.altered_pipeline(texts))
        return self.get_outputs(self.original_pipeline(texts))

    def forward(self , texts):
        """Compute the loss of the altered path against the original path.

        The original path acts as a teacher: its most likely token at every
        position is the target for the altered path's logits at that position.

        Args:
            texts: Batch of input strings.

        Returns:
            ``CausalLMOutputWithPast`` carrying the scalar loss plus the
            altered path's logits, hidden states and attentions.
        """
        # Tokenize once and share the ids/mask between both paths.
        inputs = self.get_inputIds(texts)

        # Teacher pass: only its argmax is used, so no graph is needed.
        with torch.no_grad():
            original_logits = self.get_outputs(self.get_embeddings(inputs)).logits

        altered_outputs = self.get_outputs(self.get_embeddings_altered(inputs))
        altered_logits = altered_outputs.logits

        # Logits are (batch, seq, vocab): the class axis is the LAST one.
        # Reducing over dim=1 would treat sequence positions as classes.
        original_labels = original_logits.argmax(dim=-1)

        # Padding positions carry no signal; exclude them from the loss.
        valid = inputs.attention_mask.bool()
        original_labels = original_labels.masked_fill(~valid, -100)

        vocab_size = altered_logits.size(-1)
        summed = F.cross_entropy(
            altered_logits.reshape(-1, vocab_size),
            original_labels.reshape(-1),
            ignore_index=-100,
            reduction="sum",
        )
        # Average over real tokens; the clamp avoids 0/0 for an all-padding batch.
        loss = summed / valid.sum().clamp(min=1)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=altered_logits,
            hidden_states=altered_outputs.hidden_states,
            attentions=altered_outputs.attentions,
        )

# Params

In [21]:
# Batch size used by both DataLoaders and by the Lightning logging calls.
batch_size = 32
# Sequence length passed to NITConfig as max_tokens.
max_tokens = 100
# Embedding width passed to NITConfig as llm_dim (name spelling kept: it is referenced below).
embedding_lenght = 1536

# Dataloaders

In [22]:
# 1-based index of the slice of rows to use; rows
# [split_size*(split-1), split_size*split) are selected below.
split=1
# Number of rows per slice.
split_size = 395000

In [23]:
# Positional slice of the query column for the chosen split.
X_train_query = df["query"].iloc[split_size * (split - 1):split_size * split]

In [24]:
# Work on a plain NumPy array of the selected queries.
data_array = X_train_query.to_numpy()

# Split the data into train and test sets: 20% is held out, and the fixed
# random_state keeps the split identical across runs.
X_train, X_test = train_test_split(data_array, test_size=0.2, random_state=42)

# Create DataLoaders; only the training loader shuffles.
train_loader = DataLoader(X_train, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(X_test, batch_size=batch_size, shuffle=False)

# Alignment Model Implementations

In [25]:
class SimpleAlignmentModel(AlignmentModel):
    """Two stacked linear layers over the flattened input (no non-linearity).

    Prints a notice on construction that it is meant for testing only.
    """

    def __init__(self, input_shape, output_shape):
        super().__init__(input_shape, output_shape)
        print("Testing Purpose only Don't use this model.")
        # The linear layers operate on each sample flattened to one vector.
        flat_in = input_shape[0] * input_shape[1]
        flat_out = output_shape[0] * output_shape[1]
        self.flatten = nn.Flatten()
        self.linear1 = nn.Linear(flat_in, 1000)
        self.linear2 = nn.Linear(1000, flat_out)

    def forward(self, x):
        """Flatten, apply both linear layers, and reshape to ``output_shape``."""
        hidden = self.linear1(self.flatten(x))
        projected = self.linear2(hidden)
        rows, cols = self.output_shape[0], self.output_shape[1]
        return projected.view(-1, rows, cols)

In [26]:
class LinearAlignementModel(AlignmentModel):
    """Fully connected autoencoder over the flattened embedding block.

    Encoder: flat input -> 2048 -> 1024. Decoder: 1024 -> 2048 -> flat output,
    followed by ``Tanh``.
    """

    def __init__(self, input_shape, output_shape):
        super().__init__(input_shape, output_shape)
        flat_in = input_shape[0] * input_shape[1]
        flat_out = output_shape[0] * output_shape[1]
        wide = 1024 * 2
        bottleneck = 1024

        self.flatten = nn.Flatten()

        # Compress the flattened sample down to the bottleneck width.
        self.encoder = nn.Sequential(
            nn.Linear(flat_in, wide),
            nn.LeakyReLU(0.1),
            nn.Linear(wide, bottleneck),
        )

        # Expand back to the flattened output size; Tanh bounds the result.
        self.decoder = nn.Sequential(
            nn.Linear(bottleneck, wide),
            nn.LeakyReLU(0.1),
            nn.Linear(wide, flat_out),
            nn.Tanh(),
        )

    def forward(self, x):
        """Encode and decode ``x``, returning a tensor shaped ``(-1, *output_shape)``."""
        latent = self.encoder(self.flatten(x))
        reconstructed = self.decoder(latent)
        rows, cols = self.output_shape[0], self.output_shape[1]
        return reconstructed.view(-1, rows, cols)

# Utility Functions for Saving/Loading Models and Checkpoints

In [27]:
def save_checkpoint(model, optimizer, epoch, loss, filename='checkpoint.pth'):
    """Write model and optimizer state plus epoch and loss to ``filename``.

    Args:
        model: Object providing ``state_dict()``.
        optimizer: Object providing ``state_dict()``.
        epoch: Value stored under ``'epoch'``.
        loss: Value stored under ``'loss'``.
        filename: Destination path handed to ``torch.save``.
    """
    torch.save(
        dict(
            model_state_dict=model.state_dict(),
            optimizer_state_dict=optimizer.state_dict(),
            epoch=epoch,
            loss=loss,
        ),
        filename,
    )
    print(f"Checkpoint saved at {filename}")

In [28]:
def load_checkpoint(model, optimizer, filename='checkpoint.pth'):
    """Restore model (and optionally optimizer) state from ``filename``.

    Args:
        model: Module whose ``load_state_dict`` receives ``'model_state_dict'``.
        optimizer: Optimizer whose ``load_state_dict`` receives
            ``'optimizer_state_dict'``. Pass ``None`` to restore only the model.
        filename: Path of a checkpoint written by ``save_checkpoint``.

    Returns:
        ``(epoch, loss)`` as stored in the checkpoint, or ``(0, None)`` when
        ``filename`` does not exist.
    """
    if not os.path.isfile(filename):
        print(f"No checkpoint found at {filename}")
        return 0, None

    # Deserialize onto the CPU so a checkpoint written on a GPU machine also
    # loads on a CPU-only one; load_state_dict copies the values into the
    # existing parameters wherever they live.
    checkpoint = torch.load(filename, map_location="cpu")
    model.load_state_dict(checkpoint['model_state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    print(f"Checkpoint loaded from {filename}, starting at epoch {start_epoch} with loss {loss}")
    return start_epoch, loss

In [29]:
def save_model(model, filename='model.pth'):
    """Save only the model's ``state_dict`` to ``filename`` and report the path."""
    weights = model.state_dict()
    torch.save(weights, filename)
    print(f"Model saved to {filename}")

In [30]:
def load_model(model, filename='model.pth'):
    """Load a ``state_dict`` from ``filename`` into ``model`` and switch it to eval mode.

    Args:
        model: Module to receive the weights.
        filename: Path of a file written by ``save_model``.
    """
    # Deserialize onto the CPU so weights saved on a GPU machine also load on
    # a CPU-only one; load_state_dict copies them to the model's own device.
    model.load_state_dict(torch.load(filename, map_location="cpu"))
    model.eval()  # Set the model to evaluation mode
    print(f"Model loaded from {filename}")

# Testing Model

In [31]:
class Conv1dAutoencoder(AlignmentModel):
    """Per-token channel autoencoder with a fully connected bottleneck.

    Kernel-size-1 convolutions shrink the feature axis of every token
    (features -> 800 -> 400 -> 200) without mixing tokens, two linear layers
    mix information across the whole flattened ``(200, seq_len)`` block through
    a 2048-wide bottleneck, and transposed kernel-size-1 convolutions expand
    the feature axis back (200 -> 400 -> 800 -> features) followed by ``Tanh``.

    Args:
        input_shape: ``(seq_len, features)`` of one input sample. ``None``
            falls back to ``(100, 1536)``.
        output_shape: ``(seq_len, features)`` of one output sample. ``None``
            falls back to ``(100, 1536)``.

    Raises:
        ValueError: If input and output sequence lengths differ; the
            kernel-size-1 layers cannot change the sequence length.
    """

    def __init__(self, input_shape, output_shape):
        # Honour the shapes the caller passes instead of silently replacing
        # them with a hard-coded (100, 1536).
        input_shape = (100, 1536) if input_shape is None else tuple(input_shape)
        output_shape = (100, 1536) if output_shape is None else tuple(output_shape)
        super().__init__(input_shape, output_shape)

        seq_len, in_features = input_shape
        out_seq_len, out_features = output_shape
        if seq_len != out_seq_len:
            raise ValueError(
                f"input and output sequence lengths must match, got {seq_len} and {out_seq_len}"
            )

        latent_channels = 200
        # Remembered for the reshape in forward(); plain ints, not parameters.
        self._seq_len = seq_len
        self._latent_channels = latent_channels

        # Encoder: channels in_features -> 800 -> 400 -> 200, length unchanged.
        self.encoder = nn.Sequential(
            nn.Conv1d(in_channels=in_features, out_channels=800, kernel_size=1),
            nn.LeakyReLU(0.1),

            nn.Conv1d(in_channels=800, out_channels=400, kernel_size=1),
            nn.LeakyReLU(0.1),

            nn.Conv1d(in_channels=400, out_channels=latent_channels, kernel_size=1),
            nn.LeakyReLU(0.1),
        )

        # Decoder: channels 200 -> 400 -> 800 -> out_features, length unchanged.
        self.decoder = nn.Sequential(
            nn.ConvTranspose1d(in_channels=latent_channels, out_channels=400, kernel_size=1),
            nn.LeakyReLU(0.1),

            nn.ConvTranspose1d(in_channels=400, out_channels=800, kernel_size=1),
            nn.LeakyReLU(0.1),

            nn.ConvTranspose1d(in_channels=800, out_channels=out_features, kernel_size=1),
            nn.Tanh()  # Use Tanh to output values in the range [-1, 1]
        )
        self.flatten = nn.Flatten()
        # Bottleneck across the whole flattened (latent_channels, seq_len) block.
        self.fc1 = nn.Linear(seq_len * latent_channels, 2048)
        self.fc2 = nn.Linear(2048, seq_len * latent_channels)

    def forward(self, x):
        """Map ``(batch, seq_len, in_features)`` to ``(batch, seq_len, out_features)``."""
        # Conv1d wants channels first: (batch, in_features, seq_len).
        x = x.transpose(1, 2)

        # (batch, latent_channels, seq_len)
        encoded = self.encoder(x)

        # Linear bottleneck over the flattened block, then back to
        # (batch, latent_channels, seq_len).
        bottle_neck = self.fc1(self.flatten(encoded))
        decoded_fc = self.fc2(bottle_neck)
        encoded = decoded_fc.view(-1, self._latent_channels, self._seq_len)

        # (batch, out_features, seq_len)
        decoded = self.decoder(encoded)

        # Back to (batch, seq_len, out_features).
        return decoded.transpose(1, 2)

In [32]:
# NITConfig spells this option `freez_encoder`; the previous `freeze_encoder`
# keyword did not match any named parameter and fell into **kwargs.
config = NITConfig(
    llm_name="Qwen/Qwen2.5-Math-1.5B",
    llm_dim=embedding_lenght,
    freez_encoder=True,
    freez_llm=True,
    alignment_class=Conv1dAutoencoder,
    max_tokens=max_tokens,
    model_cache_dir= model_dir
)

In [33]:
# Build the model: loads the tokenizer and LM named in `config` and
# instantiates the configured alignment class.
nit_model = NITModel(config)

In [34]:
import pytorch_lightning as pl

In [35]:
class NITModelLightning(pl.LightningModule):
    """Lightning wrapper that trains a model whose forward output exposes ``.loss``.

    Args:
        model: Wrapped model; called with the raw batch.
        lr: Learning rate for the Adam optimizer.
    """

    def __init__(self, model, lr=1e-3):
        super().__init__()
        self.model = model
        self.lr = lr

    def forward(self, x):
        """Delegate to the wrapped model."""
        return self.model(x)

    def _loss_and_log(self, batch, name, on_step):
        """Compute the batch loss and log it under ``name``."""
        loss = self.model(batch).loss
        # Weight the epoch aggregate by the real number of samples in this
        # batch (the final batch can be smaller) rather than by a notebook
        # global that may not match.
        self.log(name, loss.mean(), on_step=on_step, on_epoch=True, prog_bar=True, batch_size=len(batch))
        return loss

    def training_step(self, batch, batch_idx):
        """Return the training loss; logged per step and per epoch."""
        return self._loss_and_log(batch, "train_loss", on_step=True)

    def validation_step(self, batch, batch_idx):
        """Return the validation loss; logged per epoch only."""
        return self._loss_and_log(batch, "val_loss", on_step=False)

    def configure_optimizers(self):
        """Adam over all parameters of the wrapped model."""
        return torch.optim.Adam(self.model.parameters(), lr=self.lr)

In [36]:
# Restore the Lightning wrapper from the saved checkpoint; `model=nit_model`
# supplies the constructor argument NITModelLightning.__init__ requires.
model = NITModelLightning.load_from_checkpoint("checkpoints/best-checkpoint.ckpt", model=nit_model)

In [37]:
# One-element batch holding the first held-out query.
q = [X_test[0]]

In [38]:
# Show the sample query.
q

['How many ways are there to put 4 distinguishable balls into 2 distinguishable boxes?']

In [40]:
# Move the model to the CPU for the inspection cells below.
nit_model.to("cpu")

NITModel(
  (llm): Qwen2ForCausalLM(
    (model): Qwen2Model(
      (embed_tokens): Embedding(151936, 1536)
      (layers): ModuleList(
        (0-27): 28 x Qwen2DecoderLayer(
          (self_attn): Qwen2SdpaAttention(
            (q_proj): Linear(in_features=1536, out_features=1536, bias=True)
            (k_proj): Linear(in_features=1536, out_features=256, bias=True)
            (v_proj): Linear(in_features=1536, out_features=256, bias=True)
            (o_proj): Linear(in_features=1536, out_features=1536, bias=False)
            (rotary_emb): Qwen2RotaryEmbedding()
          )
          (mlp): Qwen2MLP(
            (gate_proj): Linear(in_features=1536, out_features=8960, bias=False)
            (up_proj): Linear(in_features=1536, out_features=8960, bias=False)
            (down_proj): Linear(in_features=8960, out_features=1536, bias=False)
            (act_fn): SiLU()
          )
          (input_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)
          (post_attention_layernorm): Qwen

In [44]:
# Inspection only: disable autograd so the two LLM forward passes do not
# build and hold a gradient graph.
with torch.no_grad():
    original_outputs = model.model.pipeline_outputs(q)
    altered_outputs = model.model.pipeline_outputs(q,altered=True)

In [56]:
# Embeddings of the sample query after the alignment model, with attention mask.
altered  = nit_model.altered_pipeline(q)

In [51]:
# Unmodified LM input embeddings of the sample query, with attention mask.
original = nit_model.original_pipeline(q)

In [57]:
# Generate from the altered embeddings. The mask comes from the same
# `altered` object as the embeddings (both pipelines tokenize `q` identically),
# so this cell no longer depends on the separately executed `original` cell.
# `return_dict=True` was dropped: it is not a generation option (that flag is
# `return_dict_in_generate`) and was only being forwarded to the model call.
# The result stays a tensor of token ids, as batch_decode below expects.
outputs = nit_model.llm.generate(
    input_ids=None,
    inputs_embeds=altered.input_embeds,
    attention_mask=altered.attention_mask,
    pad_token_id=nit_model.tokenizer.pad_token_id,  # Padding token ID
    eos_token_id=nit_model.tokenizer.eos_token_id,  # End-of-sequence token ID
    no_repeat_ngram_size=2,
)

In [58]:
# Turn the generated token ids back into text, dropping special tokens.
nit_model.tokenizer.batch_decode(outputs, skip_special_tokens=True)

[' Finally, we can solve for $x$ by dividing both sides by $-10$:\n\n$x = \\frac{11}{-20}$\n\n$x \\approx -0.55$\n\nSo, the solution to the equation is $ x = -\\frac {13}{2} $ or $ \\color{#DF0030}{x =-6.25} $.']

In [53]:
# Mean squared difference between altered and original input embeddings for
# the sample query (both pipelines are re-run here).
F.mse_loss(nit_model.altered_pipeline(q).input_embeds, nit_model.original_pipeline(q).input_embeds)

tensor(0.8663, grad_fn=<MseLossBackward0>)

In [None]:
# load_checkpoint(nit_model,optimizer)

  checkpoint = torch.load(filename)


Checkpoint loaded from checkpoint.pth, starting at epoch 0 with loss -7550580.039974684


(0, -7550580.039974684)