In [None]:
import torch
import pickle
import gzip
import os
from torch.utils.data import DataLoader
from datetime import datetime

import sys
sys.path.append("../")

from shared_utils.data import CSVPromptDataset
from shared_utils.load import get_tokenizer, configs_from_yaml
from shared_utils.generate import generate_text

from early_exit.util import get_model

from early_exit.patching import replace_attention_layers, set_transformer_early_exit_mode

In [2]:
#num_exit_samples = 4
device = "cuda"
model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
model_config_path = "../config_deepseek.yaml"
dataset_path = "../results_and_data/early_exit_sft_dataset/test/data.csv"
prompt_config_path = "../results_and_data/early_exit_sft_dataset/test/prompt_config.json"
batch_size = 1
chunk_size = 100 #for saving data

In [3]:
#output_dir = "teacher_generated_data"
output_dir = "/workspace/data/teacher_generated_data_gzip"
os.makedirs(output_dir, exist_ok=True)

In [4]:
tokenizer = get_tokenizer(model_name)
config = configs_from_yaml(model_config_path, tokenizer.eos_token_id)
model = get_model(model_name, config['model'], device)

In [5]:
dataset = CSVPromptDataset(dataset_path, prompt_config_path)
dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=dataset.collate_fn, shuffle=False)

In [6]:
model = replace_attention_layers(model, config['lora'], device)
model.eval() #not training

replacing layer model.layers.0
replacing layer model.layers.1
replacing layer model.layers.2
replacing layer model.layers.3
replacing layer model.layers.4
replacing layer model.layers.5
replacing layer model.layers.6
replacing layer model.layers.7
replacing layer model.layers.8
replacing layer model.layers.9
replacing layer model.layers.10
replacing layer model.layers.11
replacing layer model.layers.12
replacing layer model.layers.13
replacing layer model.layers.14
replacing layer model.layers.15
replacing layer model.layers.16
replacing layer model.layers.17
replacing layer model.layers.18
replacing layer model.layers.19
replacing layer model.layers.20
replacing layer model.layers.21
replacing layer model.layers.22
replacing layer model.layers.23
replacing layer model.layers.24
replacing layer model.layers.25
replacing layer model.layers.26
replacing layer model.layers.27
address this hack!
trainable params: 2,179,072 || all params: 1,779,310,108 || trainable%: 0.1225


PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): DynamicallyTypedModelWithReadout(
      (model): Qwen2Model(
        (embed_tokens): Embedding(151936, 1536)
        (layers): ModuleList(
          (0-27): 28 x DynamicallyTypedLayerWithExit(
            (self_attn): Qwen2Attention(
              (q_proj): lora.Linear(
                (base_layer): Linear(in_features=1536, out_features=1536, bias=True)
                (lora_dropout): ModuleDict(
                  (early_exiter): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (early_exiter): Linear(in_features=1536, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (early_exiter): Linear(in_features=8, out_features=1536, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
           

In [12]:
all_teacher_data = []

In [13]:
current_chunk_data = []
chunk_idx = 0
total_samples_processed = 0

In [9]:
metadata = {
    'model_name': model_name,
    'dataset_path': dataset_path,
    'prompt_config_path': prompt_config_path,
    'config': config,
    'chunk_size': chunk_size,
    'timestamp': datetime.now().isoformat(),
    'system_prompt': dataset.system_prompt,
    'prefiller': dataset.prefiller
}
metadata_path = os.path.join(output_dir, "metadata.pkl.gz")
with gzip.open(metadata_path, 'wb', compresslevel=6) as f:
    pickle.dump(metadata, f, protocol=pickle.HIGHEST_PROTOCOL)

In [14]:
def save_chunk(chunk_data, chunk_index, output_directory):
    """Save a chunk of data to disk with gzip compression"""
    chunk_filename = os.path.join(output_directory, f"chunk_{chunk_index:04d}.pkl.gz")

    # Convert all float32 tensors to float16 before saving
    compressed_data = []
    for sample in chunk_data:
        compressed_sample = {}
        for key, value in sample.items():
            if isinstance(value, torch.Tensor) and value.dtype == torch.float32:
                compressed_sample[key] = value.half()  # Convert to float16
            else:
                compressed_sample[key] = value
        compressed_data.append(compressed_sample)
    
    with gzip.open(chunk_filename, 'wb', compresslevel=9) as f: #play with compresslevel, 6 is moderate
        pickle.dump({
            'chunk_idx': chunk_index,
            'data': chunk_data,
            'num_samples': len(chunk_data)
        }, f, protocol=pickle.HIGHEST_PROTOCOL)
    
    file_size_mb = os.path.getsize(chunk_filename) / (1024 * 1024)
    print(f"Saved chunk {chunk_index} with {len(chunk_data)} samples")
    print(f"  File: {chunk_filename}")
    print(f"  Size: {file_size_mb:.2f} MB")

In [None]:
for batch_idx, prompt_batch in enumerate(dataloader):
    # Remove the testing limit if you want to process all data
    #if total_samples_processed >= 30:
    #    break
    
    if total_samples_processed % 50 == 0:
        print(f"Processing batch {total_samples_processed + 1}/{len(dataloader)} (Total samples: {total_samples_processed})")
    
    with torch.no_grad():
        # Generate SFT targets
        set_transformer_early_exit_mode(model, 'sft_teacher')
        sft_teacher_response, (sft_teacher_generated_tokens, sft_teacher_final_layer_logprobs, gathered_early_exit_hidden_states) = \
            generate_text(
                model=model, 
                prompt=prompt_batch.full_user_prompt, 
                system_prompt=dataset.system_prompt, 
                prefiller=dataset.prefiller, 
                tokenizer=tokenizer, 
                generation_config=config['generation'], 
                device=device
            )
        
        # Compute early exit probabilities
        early_output_log_probs = model.early_exit_hidden_state_readout(gathered_early_exit_hidden_states)
        
        # KL divergence calculations
        teacher_expanded = sft_teacher_final_layer_logprobs.unsqueeze(1).exp()
        early_output_probs = early_output_log_probs.exp()
        eps = 1e-16
        kl_div1 = - (teacher_expanded * (early_output_probs + eps).log()).sum(-1)
        kl_div2 = (teacher_expanded * ((teacher_expanded + eps) / (early_output_probs + eps)).log()).sum(-1)
        
        # Store data for this batch
        batch_data = {
            'batch_idx': batch_idx, #in case shuffled, but don't think matters really
            'prompt_idx': prompt_batch.idx[0] if hasattr(prompt_batch, 'idx') else batch_idx,
            'full_user_prompt': prompt_batch.full_user_prompt,
            'sft_teacher_response': sft_teacher_response, #full generated response text
            'sft_teacher_generated_tokens': sft_teacher_generated_tokens.cpu(), #token IDs [batch, full_length]
            'sft_teacher_final_layer_logprobs': sft_teacher_final_layer_logprobs.cpu(), #final layer logprobs [batch, gen_len, vocab]
            'kl_div1_per_layer': kl_div1.cpu(), #cross-entropy KL divergence per layer to final [batch, num_layers, gen_len]
            'kl_div2_per_layer': kl_div2.cpu(), #standard KL divergence per layer to final [batch, num_layers, gen_len]
            'exitable_layer_idxs': model.exitable_layer_idxs.cpu(),
        }
        
        current_chunk_data.append(batch_data)
        total_samples_processed += 1
        
        # Save chunk if we've reached chunk_size
        if len(current_chunk_data) >= chunk_size:
            save_chunk(current_chunk_data, chunk_idx, output_dir)
            current_chunk_data = []
            chunk_idx += 1
        
        # Clear GPU memory
        torch.cuda.empty_cache()

Processing batch 1/13309 (Total samples: 0)


In [12]:
if current_chunk_data:
    print(f"\nSaving final chunk with {len(current_chunk_data)} samples...")
    save_chunk(current_chunk_data, chunk_idx, output_dir)

In [14]:
def merge_teacher_data_chunks(output_dir, merged_filename="merged_teacher_data.pkl.gz"):
    """
    Merge all chunk files into a single file
    """
    print(f"Merging chunks from {output_dir}...")
    
    #load metadata - now gzipped
    metadata_path = os.path.join(output_dir, "metadata.pkl.gz")
    with gzip.open(metadata_path, 'rb') as f:
        metadata = pickle.load(f)
    
    #find all chunk files - now with .gz extension
    chunk_files = sorted([f for f in os.listdir(output_dir) if f.startswith("chunk_") and f.endswith(".pkl.gz")])
    print(f"Found {len(chunk_files)} chunk files to merge")
    
    all_data = []
    for chunk_file in chunk_files:
        chunk_path = os.path.join(output_dir, chunk_file)
        print(f"  Loading {chunk_file}...")
        with gzip.open(chunk_path, 'rb') as f:
            chunk_data = pickle.load(f)
            all_data.extend(chunk_data['data'])
            print(f"    Loaded {len(chunk_data['data'])} samples")
    
    #save merged data with gzip
    merged_path = os.path.join(output_dir, merged_filename)
    print(f"\nSaving merged data to {merged_path}...")
    with gzip.open(merged_path, 'wb', compresslevel=6) as f:
        pickle.dump({
            'teacher_data': all_data,
            'metadata': {
                **metadata,
                'num_samples': len(all_data),
                'num_chunks_merged': len(chunk_files)
            }
        }, f, protocol=pickle.HIGHEST_PROTOCOL)
    
    file_size_mb = os.path.getsize(merged_path) / (1024 * 1024)
    print(f"Total samples: {len(all_data)}")
    print(f"Merged file size: {file_size_mb:.2f} MB")
    print(f"Saved to: {merged_path}")
    
    #return all_data

In [15]:
merge_teacher_data_chunks(output_dir)

Merging chunks from /workspace/data/teacher_generated_data_gzip_test...
Found 3 chunk files to merge
  Loading chunk_0000.pkl.gz...
    Loaded 10 samples
  Loading chunk_0001.pkl.gz...
    Loaded 10 samples
  Loading chunk_0002.pkl.gz...
    Loaded 10 samples

Saving merged data to /workspace/data/teacher_generated_data_gzip_test/merged_teacher_data.pkl.gz...
Total samples: 30
Merged file size: 4672.37 MB
Saved to: /workspace/data/teacher_generated_data_gzip_test/merged_teacher_data.pkl.gz
