In [1]:
import transformers
from finetune_gsm8kcot_ae import ModelArguments, TrainingArguments, DataArguments

model_args, training_args, data_args = ModelArguments(), TrainingArguments(output_dir="./output"), DataArguments()

In [13]:
from datasets import load_dataset

ds = load_dataset("ankner/gsm8k-CoT")
train_dataset = ds["train"]
eval_dataset = ds["test"]

In [14]:
train_dataset = train_dataset.map(lambda example: {**example, "text": f"{example['question']}\n{example['response']}"})
eval_dataset = eval_dataset.map(lambda example: {**example, "text": f"{example['question']}\n{example['response']}"})

In [7]:
from training_utils import pretrain_tokenize_function
from peft import (
    LoraConfig,
)

model_args = ModelArguments()
training_args = TrainingArguments(output_dir="./output")
data_args = DataArguments()

lora_config = LoraConfig(
    r=model_args.lora_r,
    lora_alpha=model_args.lora_alpha,
    lora_dropout=model_args.lora_dropout,
    bias="none",
    task_type="CAUSAL_LM"
)

In [8]:
from modeling_icae_multi_span import ICAE
model = ICAE(model_args, training_args, lora_config)

You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.
The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


Freezing the decoder...
trainable params: 13639680 || all params: 2485278720 || trainable%: 0.5488189268365039
Enabling gradient checkpointing...


In [15]:
from finetune_gsm8kcot_ae import preprocess_function

memory_size = training_args.fixed_mem_size
MEM_TOKENS = list(range(model.vocab_size, model.vocab_size + memory_size))

train_dataset = train_dataset.map(preprocess_function)
eval_dataset = eval_dataset.map(preprocess_function)

In [17]:
train_dataset = train_dataset.select([0])
eval_dataset = eval_dataset.select([0])

In [20]:
from training_utils import pretrain_tokenize_function

train_dataset = train_dataset.map(pretrain_tokenize_function, batched=True, batch_size=1, fn_kwargs={"model": model, "mem": MEM_TOKENS, "lm_ratio": training_args.lm_ratio})

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

["Let's say the first ship had x people.\nThe second ship had 2x people.\nThe third ship had 4x people.\nThe total number of people consumed is 847, which means:\nx + 2x + 4x = 847\n7x = 847\nx = 121"]
{'input_ids': [[128000, 10267, 596, 2019, 279, 1176, 8448, 1047, 865, 1274, 627, 791, 2132, 8448, 1047, 220, 17, 87, 1274, 627, 791, 4948, 8448, 1047, 220, 19, 87, 1274, 627, 791, 2860, 1396, 315, 1274, 27073, 374, 220, 25125, 11, 902, 3445, 512, 87, 489, 220, 17, 87, 489, 220, 19, 87, 284, 220, 25125, 198, 22, 87, 284, 220, 25125, 198, 87, 284, 220, 7994]]}


Map:   0%|          | 0/1 [00:00<?, ? examples/s]

['A robe needs 2 bolts of blue fiber.\nThe amount of white fiber needed is half of the blue fiber.\nHalf of 2 bolts is 1 bolt of white fiber.\nThe total bolts needed is the sum of blue and white fiber.\n2 bolts plus 1 bolt equals 3 bolts.']
{'input_ids': [[128000, 32, 63719, 3966, 220, 17, 49939, 315, 6437, 24722, 627, 791, 3392, 315, 4251, 24722, 4460, 374, 4376, 315, 279, 6437, 24722, 627, 43727, 315, 220, 17, 49939, 374, 220, 16, 32942, 315, 4251, 24722, 627, 791, 2860, 49939, 4460, 374, 279, 2694, 315, 6437, 323, 4251, 24722, 627, 17, 49939, 5636, 220, 16, 32942, 17239, 220, 18, 49939, 13]]}


In [24]:
train_dataset[0].keys()

dict_keys(['question', 'answer', 'response', 'text', 'reasoning_trace', 'input_ids', 'prompt_answer_ids', 'labels'])

In [25]:
train_dataset[0]['reasoning_trace']

"Let's say the first ship had x people.\nThe second ship had 2x people.\nThe third ship had 4x people.\nThe total number of people consumed is 847, which means:\nx + 2x + 4x = 847\n7x = 847\nx = 121"

In [39]:

decoded_text = model.tokenizer.decode(train_dataset[0]['labels'][3:], skip_special_tokens=True)
print(decoded_text)


Let's say the first ship had x people.
The second ship had 2x people.
The third ship had 4x people.
The total number of people consumed is 847, which means:
x + 2x + 4x = 847
7x = 847
x = 121#


In [40]:
!nvidia-smi

Thu Feb 20 05:40:03 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.127.08             Driver Version: 550.127.08     CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA RTX A6000               On  |   00000000:06:00.0 Off |                  Off |
| 30%   34C    P8             22W /  300W |     271MiB /  49140MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [41]:
from modeling_icae_multi_span import ICAE
model = ICAE(model_args, training_args, lora_config)



Freezing the decoder...
trainable params: 13639680 || all params: 2485278720 || trainable%: 0.5488189268365039
Enabling gradient checkpointing...


In [42]:
import torch

model.load_state_dict(torch.load("output/model_weights.pth"), strict=False)

  model.load_state_dict(torch.load("output/model_weights.pth"), strict=False)


<All keys matched successfully>

In [45]:
import torch
from tqdm import tqdm


def run_inference(model, lines):
    model.eval()
    outputs = []
    print("Running inference")
    with torch.no_grad():
        for line in tqdm(lines):
            print("=========================== START ============================")
            print("Current line: ", line)
            # Tokenize input text
            tokenized_text = model.tokenizer(line, truncation=True,
                                          max_length=5120, padding=False,
                                          return_attention_mask=False)
            # Generate compressed outputs
            input_ids = torch.LongTensor([tokenized_text['input_ids']]).to(device)
            print("input_ids shape: ", input_ids.size())
            memory_slots = model._compress(input_ids)
            print("memory_slots shape: ", memory_slots.size())
            
            # prompt_output = model.tokenizer(data['prompt'], add_special_tokens=False, padding=False)
            prompt_ids = torch.LongTensor([[model.ae_token_id]]).to(device)
            print("prompt_ids shape: ", prompt_ids.size())

            prompt_answer_embs = model.tokens_to_embeddings(prompt_ids)
            print("prompt_answer_embs shape: ", prompt_answer_embs.size())

            memory_slots = memory_slots.to(prompt_answer_embs)
                        
            # Concatenate and clone input embeddings
            decoder_input_embeddings = torch.cat((memory_slots.unsqueeze(0), prompt_answer_embs), dim=1)
            print("decoder_input_embeddings shape: ", decoder_input_embeddings.size())

            output = decoder_input_embeddings.clone()
            print("output shape: ", output.size())

            generate_text = []
            past_key_values = None

            # Generate text output
            for i in range(512):
                with model.icae.disable_adapter():   # no independent decoder; use self.icae
                    out = model.icae(inputs_embeds=output, past_key_values=past_key_values, use_cache=True)
                logit = out.logits[:, -1, :model.vocab_size-1]
                past_key_values = out.past_key_values

                next_token_id = torch.argmax(logit, dim=-1)
                # print(next_token_id)
                
                if next_token_id.item() == 2:   # eos
                    break

                output = model.icae.get_base_model().model.embed_tokens(next_token_id).unsqueeze(1).to(device)
                generate_text.append(next_token_id.item())

            generated_text = model.tokenizer.decode(generate_text)
            outputs.append(generated_text)

            print("=========================== END ============================")

    return outputs, memory_slots

In [49]:
model = model.to("cuda")

In [51]:
lines = [
    "Four adults with 32 teeth went to the dentist for a checkup after realizing they were having severe tooth pain. They were found to have different numbers of damaged teeth, and each person had some teeth removed. The first person had 1/4 of all his teeth removed, and the second person had 3/8 of his teeth removed, the third person had half of his teeth removed, while the last person only had 4 teeth removed. What's the total number of teeth removed at the dental clinic?"
]

device = "cuda"
outputs, memory_slots = run_inference(model, lines)

Running inference


  0%|          | 0/1 [00:00<?, ?it/s]

Current line:  Four adults with 32 teeth went to the dentist for a checkup after realizing they were having severe tooth pain. They were found to have different numbers of damaged teeth, and each person had some teeth removed. The first person had 1/4 of all his teeth removed, and the second person had 3/8 of his teeth removed, the third person had half of his teeth removed, while the last person only had 4 teeth removed. What's the total number of teeth removed at the dental clinic?
input_ids shape:  torch.Size([1, 105])
memory_slots shape:  torch.Size([1, 2048])
prompt_ids shape:  torch.Size([1, 1])
prompt_answer_embs shape:  torch.Size([1, 1, 2048])
decoder_input_embeddings shape:  torch.Size([1, 2, 2048])
output shape:  torch.Size([1, 2, 2048])


100%|██████████| 1/1 [00:13<00:00, 13.53s/it]






In [52]:
memory_slots.shape

torch.Size([1, 2048])

In [53]:
memory_slots

tensor([[-0.3633, -1.6016,  0.3457,  ..., -0.5391,  1.1250, -0.9766]],
       device='cuda:0', dtype=torch.bfloat16)

In [54]:
import torch

def check_memory_slots(memory_slots):
    """
    Checks whether the memory slot embeddings contain meaningful values 
    or are close to zero.

    Args:
        memory_slots (torch.Tensor): The compressed memory representation from ICAE.
    """
    # Move to CPU for easier analysis
    memory_slots = memory_slots.detach().cpu()

    # Compute statistics
    mean_value = memory_slots.mean().item()
    std_value = memory_slots.std().item()
    min_value = memory_slots.min().item()
    max_value = memory_slots.max().item()

    print("\n📌 **Memory Slot Debugging** 📌")
    print(f"➡️ Mean Value: {mean_value:.6f}")
    print(f"➡️ Standard Deviation: {std_value:.6f}")
    print(f"➡️ Min Value: {min_value:.6f}")
    print(f"➡️ Max Value: {max_value:.6f}")

    # Check if most values are near zero
    zero_threshold = 1e-5
    near_zero_ratio = (torch.abs(memory_slots) < zero_threshold).float().mean().item()

    print(f"➡️ Percentage of Near-Zero Values: {near_zero_ratio * 100:.2f}%")
    
    if near_zero_ratio > 0.90:
        print("⚠️ WARNING: Memory embeddings are mostly zero! Encoder may not be learning meaningful compression.")
    else:
        print("✅ Memory embeddings contain significant nonzero values.")

    # Print a small sample of memory slot values
    print("\n🛠 Sample Memory Slot Values:")
    print(memory_slots[0, :10])  # Print first 10 values of the first memory slot

check_memory_slots(memory_slots)


📌 **Memory Slot Debugging** 📌
➡️ Mean Value: 0.023682
➡️ Standard Deviation: 2.296875
➡️ Min Value: -18.125000
➡️ Max Value: 11.187500
➡️ Percentage of Near-Zero Values: 0.00%
✅ Memory embeddings contain significant nonzero values.

🛠 Sample Memory Slot Values:
tensor([-0.3633, -1.6016,  0.3457,  1.7578, -0.7969,  4.3125,  1.3984,  1.4297,
         1.7969,  3.3438], dtype=torch.bfloat16)


In [55]:
ds = load_dataset("ankner/gsm8k-CoT")
train_dataset = ds["train"]
eval_dataset = ds["test"]

train_dataset = train_dataset.map(lambda example: {**example, "text": extract_reasoning_trace(example['response'])}).shuffle(seed=42)
eval_dataset = eval_dataset.map(lambda example: {**example, "text": extract_reasoning_trace(example['response'])}).shuffle(seed=42)
print("Dataset loaded successfully...")

Map:   0%|          | 0/7465 [00:00<?, ? examples/s]

Map:   0%|          | 0/1316 [00:00<?, ? examples/s]

Dataset loaded successfully...


In [56]:
train_dataset[0]

{'question': "Four adults with 32 teeth went to the dentist for a checkup after realizing they were having severe tooth pain. They were found to have different numbers of damaged teeth, and each person had some teeth removed. The first person had 1/4 of all his teeth removed, and the second person had 3/8 of his teeth removed, the third person had half of his teeth removed, while the last person only had 4 teeth removed. What's the total number of teeth removed at the dental clinic?",
 'answer': '40',
 'response': 'Each adult has 32 teeth initially.\n\nFor the first person, 1/4 of 32 teeth were removed: 32 × (1/4) = 8 teeth removed.\n\nFor the second person, 3/8 of 32 teeth were removed: 32 × (3/8) = 12 teeth removed.\n\nFor the third person, 1/2 of 32 teeth were removed: 32 × (1/2) = 16 teeth removed.\n\nThe fourth person had exactly 4 teeth removed.\n\nAdding all teeth removed: 8 + 12 + 16 + 4 = 40 teeth.\n\nTherefore, the final answer is 40.',
 'text': 'Each adult has 32 teeth initi

In [None]:
from torch.nn import CosineEmbeddingLoss

# Mean embedding of input
input_mean_embedding = segment_input_embedding.mean(dim=1).detach()

# Contrastive loss to ensure memory stores structured meaning
contrastive_loss_fct = CosineEmbeddingLoss(margin=0.5)
contrastive_target = torch.ones(memory_slots.shape[0]).to(memory_slots.device)
contrastive_loss = contrastive_loss_fct(memory_slots, input_mean_embedding, contrastive_target)

# Add contrastive loss to training loss
loss += 0.1 * contrastive_loss  
