In [1]:
import torch
import pickle
import gzip
import os
from torch.utils.data import DataLoader
from datetime import datetime
import gc
from typing import Iterator

import sys
sys.path.append("../")

from shared_utils.data import CSVPromptDataset
from shared_utils.load import get_tokenizer, configs_from_yaml
from shared_utils.generate import generate_text

from early_exit.util import get_model

from early_exit.patching import replace_attention_layers, set_transformer_early_exit_mode

In [10]:
# Run configuration: model, data locations and batching used by the cells below.
#num_exit_samples = 4
device = "cuda"  # passed to get_model, replace_attention_layers and generate_text
model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
model_config_path = "../config_deepseek.yaml"
dataset_path = "../results_and_data/early_exit_sft_dataset/test/data_deduplicated.csv"
prompt_config_path = "../results_and_data/early_exit_sft_dataset/test/prompt_config.json"
# NOTE(review): the generation loop stores only prompt_batch.idx[0] and counts one sample
# per batch, so it presumably relies on batch_size == 1 — confirm before raising this.
batch_size = 1
chunk_size = 100 #for saving data: number of loop iterations accumulated before a chunk file is written

In [3]:
# Directory that receives metadata.pkl.gz, the chunk_XXXX.pkl.gz files and the merged output.
# NOTE(review): the active value is an absolute, machine-specific path; the relative
# alternative is kept commented out below.
#output_dir = "teacher_generated_data"
output_dir = "/workspace/data/teacher_generated_data_gzip"
os.makedirs(output_dir, exist_ok=True)  # exist_ok=True: no error when the directory already exists

In [4]:
# Load the tokenizer, parse the YAML config (which is given the tokenizer's EOS token id),
# then build the model on `device`.
tokenizer = get_tokenizer(model_name)
config = configs_from_yaml(model_config_path, tokenizer.eos_token_id)  # indexed below with 'model', 'lora' and 'generation'
model = get_model(model_name, config['model'], device)

tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

config.json:   0%|          | 0.00/679 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/3.55G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/181 [00:00<?, ?B/s]

In [5]:
# Prompt dataset plus its loader. The dataset supplies its own collate_fn, and shuffle=False
# keeps the iteration order fixed from run to run.
dataset = CSVPromptDataset(dataset_path, prompt_config_path)
dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=dataset.collate_fn, shuffle=False)

In [6]:
# Patch the model with replace_attention_layers using the 'lora' section of the config.
# NOTE(review): `model` is rebound to the patched object, so re-running this cell would
# presumably patch an already-patched model — re-run the loading cell first; confirm.
model = replace_attention_layers(model, config['lora'], device)
model.eval() #not training: puts modules such as dropout into inference behaviour

replacing layer model.layers.0
replacing layer model.layers.1
replacing layer model.layers.2
replacing layer model.layers.3
replacing layer model.layers.4
replacing layer model.layers.5
replacing layer model.layers.6
replacing layer model.layers.7
replacing layer model.layers.8
replacing layer model.layers.9
replacing layer model.layers.10
replacing layer model.layers.11
replacing layer model.layers.12
replacing layer model.layers.13
replacing layer model.layers.14
replacing layer model.layers.15
replacing layer model.layers.16
replacing layer model.layers.17
replacing layer model.layers.18
replacing layer model.layers.19
replacing layer model.layers.20
replacing layer model.layers.21
replacing layer model.layers.22
replacing layer model.layers.23
replacing layer model.layers.24
replacing layer model.layers.25
replacing layer model.layers.26
replacing layer model.layers.27
address this hack!
trainable params: 2,179,072 || all params: 1,779,310,108 || trainable%: 0.1225


PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): DynamicallyTypedModelWithReadout(
      (model): Qwen2Model(
        (embed_tokens): Embedding(151936, 1536)
        (layers): ModuleList(
          (0-27): 28 x DynamicallyTypedLayerWithExit(
            (self_attn): Qwen2Attention(
              (q_proj): lora.Linear(
                (base_layer): Linear(in_features=1536, out_features=1536, bias=True)
                (lora_dropout): ModuleDict(
                  (early_exiter): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (early_exiter): Linear(in_features=1536, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (early_exiter): Linear(in_features=8, out_features=1536, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
           

In [7]:
# NOTE(review): not referenced by any other cell visible in this notebook; looks like a
# leftover from an earlier version — confirm before removing.
all_teacher_data = []

In [8]:
# Mutable state shared with the generation loop and the final-chunk cell further down.
current_chunk_data = []  # records waiting to be written as the next chunk file
chunk_idx = 0  # index used in the next chunk filename
total_samples_processed = 0  # incremented once per loop iteration

In [9]:
# Run-level metadata, written once into output_dir as metadata.pkl.gz; the merge functions
# defined later in this notebook read this file back.
metadata = {
    'model_name': model_name,
    'dataset_path': dataset_path,
    'prompt_config_path': prompt_config_path,
    'config': config,  # the whole parsed config object is pickled as-is
    'chunk_size': chunk_size,
    'timestamp': datetime.now().isoformat(),  # naive local time at which this cell ran
    'system_prompt': dataset.system_prompt,
    'prefiller': dataset.prefiller
}
metadata_path = os.path.join(output_dir, "metadata.pkl.gz")
with gzip.open(metadata_path, 'wb', compresslevel=9) as f:
    pickle.dump(metadata, f, protocol=pickle.HIGHEST_PROTOCOL)

In [11]:
def save_chunk(chunk_data, chunk_index, output_directory, logprob_floor=-14.0):
    """Save a chunk of data to disk with gzip compression.

    Every float32 tensor in each sample is stored as float16 to reduce file size.
    The tensor under 'sft_teacher_final_layer_logprobs' is additionally pruned:
    entries that are not strictly greater than `logprob_floor` are set to zero and
    the result is stored as a sparse tensor. All other values are stored unchanged.

    Note that a stored log-prob of exactly 0.0 cannot be told apart from a pruned
    entry once the tensor is sparse.

    Args:
        chunk_data: list of per-sample dicts.
        chunk_index: integer used in the filename ``chunk_{index:04d}.pkl.gz``.
        output_directory: directory the chunk file is written to.
        logprob_floor: pruning threshold for the final-layer log-probs
            (default -14.0, the value this notebook has always used).
    """
    chunk_filename = os.path.join(output_directory, f"chunk_{chunk_index:04d}.pkl.gz")

    #convert all float32 tensors to float16 before saving for size
    compressed_data = []
    for sample in chunk_data:
        compressed_sample = {}
        for key, value in sample.items():
            if isinstance(value, torch.Tensor) and value.dtype == torch.float32:
                if key == 'sft_teacher_final_layer_logprobs':
                    # torch.where instead of `value * mask`: multiplying -inf (or NaN)
                    # by a False mask yields NaN, which is non-zero and would therefore
                    # be kept in the sparse tensor instead of being pruned.
                    pruned = torch.where(value > logprob_floor, value, torch.zeros_like(value))
                    compressed_sample[key] = pruned.to_sparse().half()
                else:
                    compressed_sample[key] = value.half() #converts to float16
            else:
                compressed_sample[key] = value
        compressed_data.append(compressed_sample)

    with gzip.open(chunk_filename, 'wb', compresslevel=9) as f: #play with compresslevel, 6 is moderate
        pickle.dump({
            'chunk_idx': chunk_index,
            'data': compressed_data,
            'num_samples': len(chunk_data)
        }, f, protocol=pickle.HIGHEST_PROTOCOL)

    file_size_mb = os.path.getsize(chunk_filename) / (1024 * 1024)
    print(f"Saved chunk {chunk_index} with {len(chunk_data)} samples")
    print(f"  File: {chunk_filename}")
    print(f"  Size: {file_size_mb:.2f} MB")

In [12]:
# Start from clean accumulators: iterating the dataloader always begins again at the first
# prompt, so re-running this cell with leftover state would append duplicate records and
# continue the chunk numbering of the previous run.
current_chunk_data = []
chunk_idx = 0
total_samples_processed = 0

for batch_idx, prompt_batch in enumerate(dataloader):
    # Remove the testing limit if you want to process all data
    #if total_samples_processed >= 30:
    #    break

    if total_samples_processed % 50 == 0:
        print(f"Processing batch {total_samples_processed + 1}/{len(dataloader)} (Total samples: {total_samples_processed})")

    with torch.no_grad():
        # Generate SFT targets
        set_transformer_early_exit_mode(model, 'sft_teacher')
        sft_teacher_response, (sft_teacher_generated_tokens, sft_teacher_final_layer_logprobs, gathered_early_exit_hidden_states) = \
            generate_text(
                model=model,
                prompt=prompt_batch.full_user_prompt,
                system_prompt=dataset.system_prompt,
                prefiller=dataset.prefiller,
                tokenizer=tokenizer,
                generation_config=config['generation'],
                device=device
            )

        # Compute early exit probabilities
        early_output_log_probs = model.early_exit_hidden_state_readout(gathered_early_exit_hidden_states)

        # Divergence of every early-exit distribution from the final-layer (teacher) one.
        # unsqueeze(1) adds the axis that broadcasts the teacher against the per-layer outputs.
        teacher_expanded = sft_teacher_final_layer_logprobs.unsqueeze(1).exp()
        early_output_probs = early_output_log_probs.exp()
        eps = 1e-16  # keeps log() finite when a probability is exactly zero
        kl_div1 = - (teacher_expanded * (early_output_probs + eps).log()).sum(-1)
        kl_div2 = (teacher_expanded * ((teacher_expanded + eps) / (early_output_probs + eps)).log()).sum(-1)

        # Store data for this batch
        batch_data = {
            'batch_idx': batch_idx, #in case shuffled, but don't think matters really
            'prompt_idx': prompt_batch.idx[0] if hasattr(prompt_batch, 'idx') else batch_idx,
            'full_user_prompt': prompt_batch.full_user_prompt,
            'sft_teacher_response': sft_teacher_response, #full generated response text
            'sft_teacher_generated_tokens': sft_teacher_generated_tokens.cpu(), #token IDs [batch, full_length]
            'sft_teacher_final_layer_logprobs': sft_teacher_final_layer_logprobs.cpu(), #final layer logprobs [batch, gen_len, vocab]
            'kl_div1_per_layer': kl_div1.cpu(), #cross-entropy KL divergence per layer to final [batch, num_layers, gen_len]
            'kl_div2_per_layer': kl_div2.cpu(), #standard KL divergence per layer to final [batch, num_layers, gen_len]
            'exitable_layer_idxs': model.exitable_layer_idxs.cpu(),
        }

        current_chunk_data.append(batch_data)
        total_samples_processed += 1

        # Save chunk if we've reached chunk_size
        if len(current_chunk_data) >= chunk_size:
            save_chunk(current_chunk_data, chunk_idx, output_dir)
            current_chunk_data = []
            chunk_idx += 1

        # Drop the references to this iteration's large tensors first; otherwise they stay
        # alive until the names are rebound at the end of the next generate_text call and
        # empty_cache() has nothing it can release.
        del (
            sft_teacher_generated_tokens,
            sft_teacher_final_layer_logprobs,
            gathered_early_exit_hidden_states,
            early_output_log_probs,
            teacher_expanded,
            early_output_probs,
            kl_div1,
            kl_div2,
            batch_data,
        )

        # Clear GPU memory
        torch.cuda.empty_cache()

Processing batch 1/619 (Total samples: 0)
Processing batch 51/619 (Total samples: 50)
Saved chunk 0 with 100 samples
  File: /workspace/data/teacher_generated_data_gzip/chunk_0000.pkl.gz
  Size: 83.52 MB
Processing batch 101/619 (Total samples: 100)
Processing batch 151/619 (Total samples: 150)
Saved chunk 1 with 100 samples
  File: /workspace/data/teacher_generated_data_gzip/chunk_0001.pkl.gz
  Size: 76.17 MB
Processing batch 201/619 (Total samples: 200)
Processing batch 251/619 (Total samples: 250)
Saved chunk 2 with 100 samples
  File: /workspace/data/teacher_generated_data_gzip/chunk_0002.pkl.gz
  Size: 79.44 MB
Processing batch 301/619 (Total samples: 300)
Processing batch 351/619 (Total samples: 350)
Saved chunk 3 with 100 samples
  File: /workspace/data/teacher_generated_data_gzip/chunk_0003.pkl.gz
  Size: 83.44 MB
Processing batch 401/619 (Total samples: 400)
Processing batch 451/619 (Total samples: 450)
Saved chunk 4 with 100 samples
  File: /workspace/data/teacher_generated_d

In [13]:
# Flush whatever is left after the loop: records that did not fill a whole chunk.
# Nothing is written when the last chunk was exactly full (the list is then empty).
if current_chunk_data:
    print(f"\nSaving final chunk with {len(current_chunk_data)} samples...")
    save_chunk(current_chunk_data, chunk_idx, output_dir)


Saving final chunk with 19 samples...
Saved chunk 6 with 19 samples
  File: /workspace/data/teacher_generated_data_gzip/chunk_0006.pkl.gz
  Size: 11.80 MB


In [14]:
def merge_teacher_data_chunks(
    output_dir: str,
    merged_filename: str = "merged_teacher_data_sparse.pkl.gz",
    delete_chunks: bool = True,
    compresslevel: int = 9
):
    """Stream every chunk file in `output_dir` into one gzip file of consecutive pickles.

    Layout of the merged file, in order:
      1. a header dict ``{'metadata': <contents of metadata.pkl.gz>}``
      2. one pickle per sample, in sorted chunk-filename order
      3. a footer dict ``{'_end': True, 'num_samples': <total>}``

    Only one chunk is held in memory at a time. The merged data is first written to
    ``<merged_filename>.partial`` and renamed once complete, and chunk files are
    deleted only after that rename. A failure part-way through therefore leaves
    every chunk file (and any previously merged file) untouched.

    NOTE(review): this notebook defines `merge_teacher_data_chunks` more than once;
    the definition executed last is the one a call resolves to.

    Args:
        output_dir: directory holding metadata.pkl.gz and the chunk_*.pkl.gz files.
        merged_filename: name of the merged file created inside `output_dir`.
        delete_chunks: remove the chunk files after a successful merge.
        compresslevel: gzip compression level for the merged file.

    Raises:
        Whatever gzip/pickle/os raise while reading or writing; the partial
        output is removed before the exception propagates.
    """
    print(f"Merging chunks from {output_dir}...")

    #load metadata - now gzipped
    meta_path = os.path.join(output_dir, "metadata.pkl.gz")
    with gzip.open(meta_path, "rb") as f:
        metadata = pickle.load(f)

    chunk_files = sorted(
        f for f in os.listdir(output_dir)
        if f.startswith("chunk_") and f.endswith(".pkl.gz")
    )
    print(f"Found {len(chunk_files)} chunk files")

    merged_path = os.path.join(output_dir, merged_filename)
    partial_path = merged_path + ".partial"
    total = 0

    try:
        with gzip.open(partial_path, "wb", compresslevel=compresslevel) as fout:
            # write header once
            pickle.dump({'metadata': metadata}, fout, protocol=pickle.HIGHEST_PROTOCOL)

            for i, cf in enumerate(chunk_files, 1):
                cpath = os.path.join(output_dir, cf)
                print(f"  [{i}/{len(chunk_files)}] {cf}")
                with gzip.open(cpath, "rb") as fin:
                    chunk = pickle.load(fin)          # {'chunk_idx','data',...}
                for sample in chunk['data']:
                    pickle.dump(sample, fout, protocol=pickle.HIGHEST_PROTOCOL)
                total += len(chunk['data'])

                # release this chunk before the next one is loaded
                del chunk
                gc.collect()

            pickle.dump({'_end': True, 'num_samples': total}, fout, protocol=pickle.HIGHEST_PROTOCOL)
    except BaseException:
        # Also covers KeyboardInterrupt: never leave a truncated file behind.
        if os.path.exists(partial_path):
            os.remove(partial_path)
        raise

    # Atomic on the same filesystem: readers see either the old file or the complete new one.
    os.replace(partial_path, merged_path)

    # Source chunks are removed only now that the merged file is complete.
    if delete_chunks:
        for cf in chunk_files:
            os.remove(os.path.join(output_dir, cf))
            print(f"    deleted {cf}")

    size_mb = os.path.getsize(merged_path) / (1024 * 1024)
    print(f"\nTotal samples: {total}")
    print(f"Merged file size: {size_mb:.2f} MB")
    print(f"Saved to: {merged_path}")


def iter_merged_teacher_data(merged_path: str) -> Iterator[dict]:
    """
    Yield the samples stored in a merged stream file one at a time.

    The first pickled object (the header) is read and discarded. Iteration stops at
    the first dict carrying a truthy '_end' entry (the footer), or at end of file
    when no footer is present; neither header nor footer is yielded.
    """
    with gzip.open(merged_path, "rb") as stream:
        pickle.load(stream)  # header {'metadata': ...} — read past it, not returned
        finished = False
        while not finished:
            try:
                record = pickle.load(stream)
            except EOFError:
                # stream ended without a footer
                finished = True
            else:
                if isinstance(record, dict) and record.get('_end'):
                    finished = True
                else:
                    yield record


In [9]:
def merge_teacher_data_chunks(output_dir, merged_filename="merged_teacher_data.pkl.gz"):
    """
    Combine every chunk file in `output_dir` into one gzip-compressed pickle.

    All samples are collected in memory and written as a single dict:
    ``{'teacher_data': [...], 'metadata': {...}}``, where the metadata is the content
    of metadata.pkl.gz extended with 'num_samples' and 'num_chunks_merged'.
    Chunk files are left in place. Returns None.
    """
    print(f"Merging chunks from {output_dir}...")

    # metadata is stored gzipped next to the chunks
    with gzip.open(os.path.join(output_dir, "metadata.pkl.gz"), 'rb') as meta_file:
        metadata = pickle.load(meta_file)

    # chunk files carry the .pkl.gz extension; process them in sorted name order
    chunk_files = [
        name for name in os.listdir(output_dir)
        if name.startswith("chunk_") and name.endswith(".pkl.gz")
    ]
    chunk_files.sort()
    print(f"Found {len(chunk_files)} chunk files to merge")

    all_data = []
    for chunk_file in chunk_files:
        print(f"  Loading {chunk_file}...")
        with gzip.open(os.path.join(output_dir, chunk_file), 'rb') as chunk_handle:
            samples = pickle.load(chunk_handle)['data']
        all_data.extend(samples)
        print(f"    Loaded {len(samples)} samples")

    merged_path = os.path.join(output_dir, merged_filename)
    print(f"\nSaving merged data to {merged_path}...")

    # copy of the stored metadata plus the merge statistics
    merged_metadata = dict(metadata)
    merged_metadata['num_samples'] = len(all_data)
    merged_metadata['num_chunks_merged'] = len(chunk_files)
    payload = {'teacher_data': all_data, 'metadata': merged_metadata}

    with gzip.open(merged_path, 'wb', compresslevel=9) as out_file:
        pickle.dump(payload, out_file, protocol=pickle.HIGHEST_PROTOCOL)

    file_size_mb = os.path.getsize(merged_path) / (1024 * 1024)
    print(f"Total samples: {len(all_data)}")
    print(f"Merged file size: {file_size_mb:.2f} MB")
    print(f"Saved to: {merged_path}")

In [26]:
def merge_teacher_data_chunks(
    output_dir: str,
    merged_filename: str = "merged_teacher_data_sparse_combtensor.pkl.gz",
    delete_chunks: bool = True,
    compresslevel: int = 9
):
    """Merge all chunk files into one "unified" dict and save it as a gzip pickle.

    Instead of a list of per-sample dicts, the unified structure holds parallel
    lists (one entry per sample) for each field, a list of the per-sample log-prob
    tensors, and index tensors describing them:

      - 'all_sparse_logprobs': the 'sft_teacher_final_layer_logprobs' values as stored
      - 'logprobs_metadata': 'sample_indices', 'sequence_lengths' (size of dim 1) and
        'start_positions' (running sum of the sequence lengths) as long tensors
      - 'max_seq_len' / 'vocab_size': largest dim-1 size and the size of the last
        dimension of the log-prob tensors
      - 'exitable_layer_idxs': taken from the first sample that has it

    All samples are loaded into memory. Chunk files are deleted only after the
    merged file has been written.

    NOTE(review): this notebook defines `merge_teacher_data_chunks` more than once;
    the definition executed last is the one a call resolves to.

    Args:
        output_dir: directory holding metadata.pkl.gz and the chunk_*.pkl.gz files.
        merged_filename: name of the merged file created inside `output_dir`.
        delete_chunks: remove the chunk files once the merged file is written.
        compresslevel: gzip compression level for the merged file.
    """
    print(f"Merging chunks from {output_dir}...")

    # Load metadata
    meta_path = os.path.join(output_dir, "metadata.pkl.gz")
    with gzip.open(meta_path, "rb") as f:
        metadata = pickle.load(f)

    chunk_files = sorted(
        f for f in os.listdir(output_dir)
        if f.startswith("chunk_") and f.endswith(".pkl.gz")
    )
    print(f"Found {len(chunk_files)} chunk files")

    # First pass: collect all data and find max sequence length
    all_samples = []
    max_seq_len = 0
    vocab_size = None  # filled in from the first sample that carries log-probs

    print("First pass: collecting samples and finding max sequence length...")
    for i, cf in enumerate(chunk_files, 1):
        cpath = os.path.join(output_dir, cf)
        print(f"  Reading [{i}/{len(chunk_files)}] {cf}")
        with gzip.open(cpath, "rb") as fin:
            chunk = pickle.load(fin)
        for sample in chunk['data']:
            all_samples.append(sample)
            # Check sequence length
            if 'sft_teacher_final_layer_logprobs' in sample:
                logprobs_shape = sample['sft_teacher_final_layer_logprobs'].shape
                max_seq_len = max(max_seq_len, logprobs_shape[1])
                if vocab_size is None:
                    # last dimension of the log-prob tensor is the vocabulary axis
                    vocab_size = logprobs_shape[-1]

    if vocab_size is None:
        # No sample carried log-probs: keep the value this notebook used to hard-code.
        vocab_size = 151936

    total_samples = len(all_samples)
    print(f"\nTotal samples: {total_samples}")
    print(f"Max sequence length: {max_seq_len}")
    print(f"Vocabulary size: {vocab_size}")

    # Create unified data structure
    print("\nCreating unified data structure...")
    unified_data = {
        'metadata': metadata,
        'num_samples': total_samples,
        'max_seq_len': max_seq_len,
        'vocab_size': vocab_size,
        # Store all logprobs as a list of sparse tensors with their metadata
        'all_sparse_logprobs': [],
        'logprobs_metadata': {
            'sample_indices': [],      # Which sample each tensor belongs to
            'sequence_lengths': [],    # Actual length of each sequence
            'start_positions': [],     # Where each sample starts in a hypothetical stacked tensor
        },
        # Store other data in parallel arrays
        'batch_indices': [],
        'prompt_indices': [],
        'full_user_prompts': [],
        'sft_teacher_responses': [],
        'sft_teacher_generated_tokens': [],
        'kl_div1_per_layer': [],
        'kl_div2_per_layer': [],
        'exitable_layer_idxs': None,  # Same for all samples
    }

    # Second pass: populate unified structure
    print("Second pass: creating unified structure...")
    current_position = 0

    for idx, sample in enumerate(all_samples):
        if idx % 100 == 0:
            print(f"  Processing sample {idx}/{total_samples}")

        # Store logprobs with metadata
        if 'sft_teacher_final_layer_logprobs' in sample:
            sparse_logprobs = sample['sft_teacher_final_layer_logprobs']
            seq_len = sparse_logprobs.shape[1]

            unified_data['all_sparse_logprobs'].append(sparse_logprobs)
            unified_data['logprobs_metadata']['sample_indices'].append(idx)
            unified_data['logprobs_metadata']['sequence_lengths'].append(seq_len)
            unified_data['logprobs_metadata']['start_positions'].append(current_position)
            current_position += seq_len

        # Store other data (missing keys become None so the lists stay aligned)
        unified_data['batch_indices'].append(sample.get('batch_idx'))
        unified_data['prompt_indices'].append(sample.get('prompt_idx'))
        unified_data['full_user_prompts'].append(sample.get('full_user_prompt'))
        unified_data['sft_teacher_responses'].append(sample.get('sft_teacher_response'))
        unified_data['sft_teacher_generated_tokens'].append(sample.get('sft_teacher_generated_tokens'))
        unified_data['kl_div1_per_layer'].append(sample.get('kl_div1_per_layer'))
        unified_data['kl_div2_per_layer'].append(sample.get('kl_div2_per_layer'))

        # Set exitable_layer_idxs (same for all samples)
        if unified_data['exitable_layer_idxs'] is None and 'exitable_layer_idxs' in sample:
            unified_data['exitable_layer_idxs'] = sample['exitable_layer_idxs']

    # Convert metadata lists to tensors for efficiency
    unified_data['logprobs_metadata']['sample_indices'] = torch.tensor(
        unified_data['logprobs_metadata']['sample_indices'], dtype=torch.long
    )
    unified_data['logprobs_metadata']['sequence_lengths'] = torch.tensor(
        unified_data['logprobs_metadata']['sequence_lengths'], dtype=torch.long
    )
    unified_data['logprobs_metadata']['start_positions'] = torch.tensor(
        unified_data['logprobs_metadata']['start_positions'], dtype=torch.long
    )

    # Save unified structure
    merged_path = os.path.join(output_dir, merged_filename)
    print(f"\nSaving unified structure to {merged_path}...")
    with gzip.open(merged_path, "wb", compresslevel=compresslevel) as fout:
        pickle.dump(unified_data, fout, protocol=pickle.HIGHEST_PROTOCOL)

    # Clean up
    del all_samples
    gc.collect()

    # Delete chunks if requested (only reached once the merged file has been written)
    if delete_chunks:
        for cf in chunk_files:
            cpath = os.path.join(output_dir, cf)
            os.remove(cpath)
            print(f"  Deleted {cf}")

    size_mb = os.path.getsize(merged_path) / (1024 * 1024)
    print(f"\nTotal samples: {total_samples}")
    print(f"Merged file size: {size_mb:.2f} MB")
    print(f"Saved to: {merged_path}")

def load_merged_teacher_data(merged_path: str):
    """Read and return the single pickled object stored in the gzip file at `merged_path`."""
    with gzip.open(merged_path, "rb") as merged_file:
        unified = pickle.load(merged_file)
    return unified

def get_sample_from_unified(unified_data: dict, sample_idx: int) -> dict:
    """
    Reconstruct a single sample from the unified structure.
    This shows how to access data during training.

    Args:
        unified_data: dict with the layout written by the unified merge
            ('logprobs_metadata', 'all_sparse_logprobs' and the parallel per-sample lists).
        sample_idx: position of the sample in the parallel lists.

    Returns:
        A dict with the same keys as an individual stored sample.
        'sft_teacher_final_layer_logprobs' is None when no log-prob tensor was
        recorded for this sample (the merge step tolerates such samples).

    Raises:
        IndexError: if `sample_idx` is outside the parallel lists.
    """
    # Positions in 'all_sparse_logprobs' whose owning sample is sample_idx
    matches = (unified_data['logprobs_metadata']['sample_indices'] == sample_idx).nonzero()
    if matches.numel() > 0:
        sparse_logprobs = unified_data['all_sparse_logprobs'][matches[0].item()]
    else:
        # Sample was stored without log-probs; report that explicitly instead of
        # failing with an IndexError on the empty match.
        sparse_logprobs = None

    # Reconstruct the sample
    sample = {
        'batch_idx': unified_data['batch_indices'][sample_idx],
        'prompt_idx': unified_data['prompt_indices'][sample_idx],
        'full_user_prompt': unified_data['full_user_prompts'][sample_idx],
        'sft_teacher_response': unified_data['sft_teacher_responses'][sample_idx],
        'sft_teacher_generated_tokens': unified_data['sft_teacher_generated_tokens'][sample_idx],
        'sft_teacher_final_layer_logprobs': sparse_logprobs,
        'kl_div1_per_layer': unified_data['kl_div1_per_layer'][sample_idx],
        'kl_div2_per_layer': unified_data['kl_div2_per_layer'][sample_idx],
        'exitable_layer_idxs': unified_data['exitable_layer_idxs'],
    }
    return sample

In [15]:
# NOTE(review): merge_teacher_data_chunks is defined three times in this notebook, so this
# call runs whichever definition was executed last. The output recorded below
# ("[i/7] ... deleted", file merged_teacher_data_sparse.pkl.gz) matches the streaming
# definition. Both definitions that accept delete_chunks default it to True, i.e. the
# chunk files are removed by this call.
merge_teacher_data_chunks(output_dir)

Merging chunks from /workspace/data/teacher_generated_data_gzip...
Found 7 chunk files
  [1/7] chunk_0000.pkl.gz
    deleted chunk_0000.pkl.gz
  [2/7] chunk_0001.pkl.gz
    deleted chunk_0001.pkl.gz
  [3/7] chunk_0002.pkl.gz
    deleted chunk_0002.pkl.gz
  [4/7] chunk_0003.pkl.gz
    deleted chunk_0003.pkl.gz
  [5/7] chunk_0004.pkl.gz
    deleted chunk_0004.pkl.gz
  [6/7] chunk_0005.pkl.gz
    deleted chunk_0005.pkl.gz
  [7/7] chunk_0006.pkl.gz
    deleted chunk_0006.pkl.gz

Total samples: 619
Merged file size: 472.56 MB
Saved to: /workspace/data/teacher_generated_data_gzip/merged_teacher_data_sparse.pkl.gz


In [28]:
# Size in MiB of the unified ("combtensor") merged file.
os.path.getsize(os.path.join(output_dir, "merged_teacher_data_sparse_combtensor.pkl.gz")) / (1024 * 1024)

26.803590774536133

In [29]:
# Size in MiB of the streamed merged file.
# NOTE(review): the merge log above reports 472.56 MB for this filename while this cell shows
# about 26.8 — presumably the file was regenerated between the two executions; confirm.
os.path.getsize(os.path.join(output_dir, "merged_teacher_data_sparse.pkl.gz")) / (1024 * 1024)

26.832829475402832

In [30]:
# Size in MiB of the baseline file used for comparison.
# NOTE(review): no cell visible in this notebook writes merged_teacher_data_original.pkl.gz —
# presumably produced by an earlier version of the notebook; confirm its provenance.
os.path.getsize(os.path.join(output_dir, "merged_teacher_data_original.pkl.gz")) / (1024 * 1024)

2120.7067728042603