In [1]:
import json
import pickle
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModel
from datasets import Dataset
from datasets import load_dataset
import os
import math

In [None]:
# Checkpoint whose encoder produces the embeddings; tokenizer and weights
# are both resolved from the same identifier.
model_id = "answerdotai/ModernBERT-large"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModel.from_pretrained(model_id)
model.eval()  # switch the module to evaluation mode (inference only)

# Use the GPU when one is visible to torch, otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)
torch.set_float32_matmul_precision('high')

# Load the pair dataset and drop every row whose language is "tests".
dataset = load_dataset(
    "NicholasOgenstad/my-runbugrun-dataset",
    data_files="runbugrun_all_pairs_with_language.json",
    split="train",
).filter(lambda example: example["language"] != "tests")

# Source-code columns of the filtered dataset.
buggy, fixed = dataset['buggy_code'], dataset['fixed_code']

In [None]:
def load_tokenized_chunk(save_dir, chunk_num):
    """Read one pickled token chunk and split it into buggy and fixed halves.

    The file ``chunk_<chunk_num, zero-padded to 4 digits>.pkl`` inside
    ``save_dir`` must unpickle to a mapping with the keys ``'chunk_size'``,
    ``'input_ids'`` and ``'attention_mask'``. The first ``chunk_size`` rows of
    each field are returned as the buggy encodings; the remaining rows are
    returned as the fixed encodings.

    Args:
        save_dir: Directory that holds the chunk pickles.
        chunk_num: Integer chunk index used to build the file name.

    Returns:
        A ``(buggy_tokenized, fixed_tokenized)`` tuple of dicts, each with the
        keys ``'input_ids'`` and ``'attention_mask'``.
    """
    path = os.path.join(save_dir, f'chunk_{chunk_num:04d}.pkl')
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(path, 'rb') as handle:
        payload = pickle.load(handle)

    split_at = payload['chunk_size']
    fields = ('input_ids', 'attention_mask')

    # Rows before the split point are the buggy half, rows after it the fixed half.
    buggy_tokenized = {name: payload[name][:split_at] for name in fields}
    fixed_tokenized = {name: payload[name][split_at:] for name in fields}

    return buggy_tokenized, fixed_tokenized

In [None]:
def get_mean_pooled_embeddings(input_ids, attention_mask):
    """Run the global ``model`` and mean-pool its last hidden state over tokens.

    Positions where ``attention_mask`` is 0 contribute nothing to the sum, and
    the divisor is the number of attended positions, clamped to a tiny
    positive value so an all-zero mask yields zeros instead of NaN.

    Args:
        input_ids: Token-id tensor passed to the model
            (presumably shaped (batch, seq_len) — confirm against the caller).
        attention_mask: Mask tensor with one entry per token; it is
            unsqueezed on the last axis and expanded to the hidden-state shape.

    Returns:
        Tensor of pooled embeddings, reduced over dimension 1 of the hidden
        state. No gradients are tracked.
    """
    # Mixed precision only applies on CUDA. Requesting CUDA autocast while the
    # tensors live on the CPU makes PyTorch warn and disable it on every call,
    # so it is enabled only when the inputs are actually on a CUDA device.
    use_cuda_autocast = input_ids.is_cuda
    with torch.no_grad(), torch.autocast("cuda", enabled=use_cuda_autocast):
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        hidden = outputs.last_hidden_state
        # Float mask broadcast to the hidden-state shape; padded tokens become 0.
        mask = attention_mask.unsqueeze(-1).expand_as(hidden).float()
        summed = (hidden * mask).sum(1)
        # Clamp guards against division by zero for fully masked rows.
        counts = mask.sum(1).clamp(min=1e-9)
        return summed / counts

In [None]:
def write_diff_file(chunk_num, diff_array, out_dir=None):
    """Concatenate difference tensors and pickle them as one NumPy array.

    The result is written to
    ``diff_embeddings_chunk_<chunk_num, zero-padded to 4 digits>.pkl``.

    Args:
        chunk_num: Integer chunk index used in the output file name.
        diff_array: Non-empty sequence of tensors that can be concatenated
            along dimension 0 (``torch.cat`` raises on an empty sequence).
        out_dir: Directory to write into. When ``None`` (the default) the
            module-level ``output_dir`` is used, so existing two-argument
            calls behave as before.
    """
    # Resolve the destination at call time; the global is only read as a fallback.
    target_dir = output_dir if out_dir is None else out_dir

    # Concatenate first so invalid input fails before the filesystem is touched.
    diff_embeddings = torch.cat(diff_array, dim=0)
    diff_embeddings_np = diff_embeddings.cpu().numpy()

    os.makedirs(target_dir, exist_ok=True)
    output_file = os.path.join(target_dir, f'diff_embeddings_chunk_{chunk_num:04d}.pkl')
    with open(output_file, 'wb') as f:
        pickle.dump(diff_embeddings_np, f)

In [None]:
def embed_in_batches(encodings, batch_size=192):
    """Embed tokenized rows in slices of ``batch_size`` and return CPU results.

    Every tensor in ``encodings`` is first moved to the device of the global
    ``model``. The ``'input_ids'`` and ``'attention_mask'`` tensors are then
    sliced along dimension 0 and each slice is passed to
    ``get_mean_pooled_embeddings``.

    Args:
        encodings: Mapping of tensors that contains at least ``'input_ids'``
            and ``'attention_mask'``.
        batch_size: Maximum number of rows per forward pass.

    Returns:
        A CPU tensor with the pooled outputs concatenated in row order. When
        there are no rows, an empty tensor of shape
        ``(0, model.config.hidden_size)`` is returned.
    """
    n_rows = encodings['input_ids'].shape[0]
    if n_rows == 0:
        return torch.empty(0, model.config.hidden_size).cpu()

    # Transfer the whole slice set to the model's device once, up front.
    on_device = {}
    for key, tensor in encodings.items():
        on_device[key] = tensor.to(model.device, non_blocking=True)

    ids = on_device["input_ids"]
    mask = on_device["attention_mask"]

    # Slicing past the end is clamped, so the final slice may be shorter.
    pooled_chunks = [
        get_mean_pooled_embeddings(ids[lo:lo + batch_size], mask[lo:lo + batch_size])
        for lo in range(0, n_rows, batch_size)
    ]

    return torch.cat(pooled_chunks).cpu()

In [None]:
tokenized_dir = "/mimer/NOBACKUP/groups/naiss2025-5-243/tokenized_chunks2" # Pretokenized data location
output_dir = "/mimer/NOBACKUP/groups/naiss2025-5-243/diff_embeddings2" # Location of embeddings after encoding

# Number of pairs taken from a chunk per outer step; embed_in_batches splits
# each of these slices further into model-sized batches.
step_size = 1000

# Only the chunk indices in this range are processed in this run.
for chunk_num in range(22, 23):

    buggy_data, fixed_data = load_tokenized_chunk(tokenized_dir, chunk_num)
    total_size = buggy_data['input_ids'].shape[0]

    diff_final = []
    # Iterating over the real slice starts means progress is logged only for
    # batches that exist (the previous while-loop printed one extra batch).
    for batch_idx, start in enumerate(range(0, total_size, step_size)):
        print(f"Processing chunk {chunk_num}, batch {batch_idx}")
        end = min(start + step_size, total_size)

        buggy_batch = {
            'input_ids': buggy_data['input_ids'][start:end],
            'attention_mask': buggy_data['attention_mask'][start:end]
        }
        fixed_batch = {
            'input_ids': fixed_data['input_ids'][start:end],
            'attention_mask': fixed_data['attention_mask'][start:end]
        }

        # Distinct names keep the `buggy` / `fixed` dataset columns defined in
        # the setup cell from being overwritten with embedding tensors.
        buggy_emb = embed_in_batches(buggy_batch)
        fixed_emb = embed_in_batches(fixed_batch)

        # Row i is the fixed embedding minus the buggy embedding of pair i.
        diff_batch = fixed_emb - buggy_emb
        diff_final.append(diff_batch)

    write_diff_file(chunk_num, diff_final)
    # Release cached GPU memory before moving on to the next chunk.
    torch.cuda.empty_cache()