In [1]:
import os
import pickle
from contextlib import nullcontext

import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel

# Hugging Face checkpoint that produces the embeddings.
model_id = "answerdotai/ModernBERT-large"

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained(model_id)

# Inference only: switch the module to evaluation mode and move it to the
# chosen device. Both eval() and to() return the module itself, so the calls
# can be chained without changing what `model` refers to.
model = AutoModel.from_pretrained(model_id).eval().to(device)

# Allow PyTorch to trade a little float32 matmul precision for speed.
torch.set_float32_matmul_precision('high')

2025-08-31 13:27:27.961694: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-08-31 13:27:28.048147: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-08-31 13:27:28.048186: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-08-31 13:27:28.053458: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [2]:
def load_tokenized_chunk(save_dir, chunk_num):
    """Read one pickled chunk and split it into its buggy and fixed parts.

    The file ``chunk_<NNNN>.pkl`` inside ``save_dir`` is unpickled into a
    mapping that provides ``'chunk_size'``, ``'input_ids'`` and
    ``'attention_mask'``. For each of the two sequences, the first
    ``chunk_size`` rows form the "buggy" part and everything after them forms
    the "fixed" part.

    Args:
        save_dir: Directory that contains the chunk files.
        chunk_num: Integer chunk index; zero-padded to four digits in the
            file name.

    Returns:
        A ``(buggy_tokenized, fixed_tokenized)`` tuple of dicts, each holding
        ``'input_ids'`` and ``'attention_mask'``.
    """
    path = os.path.join(save_dir, f'chunk_{chunk_num:04d}.pkl')

    # NOTE: pickle.load can execute arbitrary code; only read chunk files that
    # come from a trusted source.
    with open(path, 'rb') as handle:
        payload = pickle.load(handle)

    split_at = payload['chunk_size']
    fields = ('input_ids', 'attention_mask')

    buggy_tokenized = {name: payload[name][:split_at] for name in fields}
    fixed_tokenized = {name: payload[name][split_at:] for name in fields}
    return buggy_tokenized, fixed_tokenized

In [3]:
def get_mean_pooled_embeddings(input_ids, attention_mask):
    """Embed a batch with the global ``model`` and mean-pool over unmasked tokens.

    The model's ``last_hidden_state`` is multiplied by the attention mask,
    summed along dimension 1 and divided by the number of unmasked positions,
    so positions whose mask value is 0 do not contribute to the result.

    Args:
        input_ids: Tensor forwarded to the model as ``input_ids``
            (presumably shaped ``(batch, seq)`` -- confirm against the
            tokenization step).
        attention_mask: Tensor forwarded to the model as ``attention_mask``;
            it must have one dimension fewer than ``last_hidden_state`` so it
            can be expanded to that shape.

    Returns:
        One pooled vector per input row, computed without gradient tracking.
    """
    # Mixed precision is only meaningful on CUDA. Entering
    # torch.autocast("cuda") on a CPU-only machine makes PyTorch emit a
    # warning on every call and disable itself, so skip it entirely there.
    amp_context = torch.autocast("cuda") if input_ids.is_cuda else nullcontext()

    with torch.no_grad(), amp_context:
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        hidden = outputs.last_hidden_state
        # Broadcast the mask over the hidden dimension; .float() keeps the
        # sum and the division below in float32.
        mask = attention_mask.unsqueeze(-1).expand_as(hidden).float()
        summed = (hidden * mask).sum(1)
        # Clamp so a fully masked row divides by a tiny number instead of 0.
        counts = mask.sum(1).clamp(min=1e-9)
        return summed / counts

In [4]:
def write_pair_file(chunk_num, buggy_arrays, fixed_arrays):
    """Concatenate per-batch embeddings and pickle them for one chunk.

    The result is written to
    ``buggy_fixed_embeddings_chunk_<NNNN>.pkl`` inside the module-level
    ``output_dir`` (created if missing) as a dict with the keys
    ``'buggy_embeddings'`` and ``'fixed_embeddings'``, each a NumPy array.

    The pickle is first written to a temporary file next to the target and
    then moved into place, so an interrupted or failed write never leaves a
    truncated output file and never destroys a previously written one.

    Args:
        chunk_num: Integer chunk index; zero-padded to four digits in the
            file name.
        buggy_arrays: Non-empty list of tensors, concatenated along dim 0.
        fixed_arrays: Non-empty list of tensors, concatenated along dim 0.

    Raises:
        RuntimeError: From ``torch.cat`` if either list is empty.
        OSError: If the directory or file cannot be written.
    """
    os.makedirs(output_dir, exist_ok=True)
    buggy_embeddings = torch.cat(buggy_arrays, dim=0).cpu().numpy()
    fixed_embeddings = torch.cat(fixed_arrays, dim=0).cpu().numpy()
    output_file = os.path.join(output_dir, f'buggy_fixed_embeddings_chunk_{chunk_num:04d}.pkl')

    # Same directory as the target so os.replace stays on one filesystem.
    temp_file = output_file + '.tmp'
    try:
        with open(temp_file, 'wb') as f:
            pickle.dump(
                {
                    'buggy_embeddings': buggy_embeddings,
                    'fixed_embeddings': fixed_embeddings
                },
                f
            )
        os.replace(temp_file, output_file)
    except BaseException:
        # Also covers KeyboardInterrupt from an interrupted notebook cell:
        # drop the partial temp file, then let the error propagate.
        if os.path.exists(temp_file):
            os.remove(temp_file)
        raise

In [5]:
def embed_in_batches(encodings, batch_size=256):
    """Mean-pool embeddings for every row of ``encodings``, a slice at a time.

    All tensors in ``encodings`` are first moved to the device of the global
    ``model``. Consecutive slices of at most ``batch_size`` rows are then
    passed to ``get_mean_pooled_embeddings`` and the pooled results are
    concatenated in order.

    Args:
        encodings: Mapping of tensors that contains at least ``'input_ids'``
            and ``'attention_mask'``.
        batch_size: Maximum number of rows per model call.

    Returns:
        The concatenated pooled embeddings as a CPU tensor.
    """
    num_rows = encodings['input_ids'].shape[0]

    on_device = {}
    for name, tensor in encodings.items():
        on_device[name] = tensor.to(model.device, non_blocking=True)

    ids_all = on_device["input_ids"]
    mask_all = on_device["attention_mask"]

    pooled_parts = []
    for lo in range(0, num_rows, batch_size):
        hi = min(lo + batch_size, num_rows)
        pooled_parts.append(
            get_mean_pooled_embeddings(ids_all[lo:hi], mask_all[lo:hi])
        )

    return torch.cat(pooled_parts).cpu()

In [None]:
tokenized_dir = "/mimer/NOBACKUP/groups/naiss2025-5-243/tokenized_chunks2" # Location of pretokenized data
output_dir = "/mimer/NOBACKUP/groups/naiss2025-5-243/buggy_fixed_embeddings" # Location of finished embeddings

# Number of rows handed to embed_in_batches at a time; each such slice is moved
# to the model's device as a whole before being split into model batches.
step_size = 1000

for chunk_num in range(0, 23): # Alter if needed
    buggy_data, fixed_data = load_tokenized_chunk(tokenized_dir, chunk_num)
    total_size = buggy_data['input_ids'].shape[0]

    if total_size == 0:
        # Nothing to embed; write_pair_file would fail on an empty list.
        print(f"Chunk {chunk_num} is empty, skipping")
        continue

    buggy_final = []
    fixed_final = []

    # Derive the slices from the chunk length instead of a fixed count so that
    # chunks of any size are processed completely and no rows are dropped.
    for batch_idx, start in enumerate(range(0, total_size, step_size)):
        print(f"Processing chunk {chunk_num}, batch {batch_idx}")
        end = min(start + step_size, total_size)

        buggy_batch = {
            'input_ids': buggy_data['input_ids'][start:end],
            'attention_mask': buggy_data['attention_mask'][start:end]
        }
        # The fixed rows are sliced with the same indices as the buggy rows,
        # which keeps the two embedding arrays aligned row by row.
        fixed_batch = {
            'input_ids': fixed_data['input_ids'][start:end],
            'attention_mask': fixed_data['attention_mask'][start:end]
        }

        buggy_emb = embed_in_batches(buggy_batch)
        fixed_emb = embed_in_batches(fixed_batch)

        buggy_final.append(buggy_emb)
        fixed_final.append(fixed_emb)
    print("Chunk completed")
    write_pair_file(chunk_num, buggy_final, fixed_final)
    # Hand cached GPU memory back between chunks.
    torch.cuda.empty_cache()

Processing chunk 19, batch 0
Processing chunk 19, batch 1
Processing chunk 19, batch 2
Processing chunk 19, batch 3
Processing chunk 19, batch 4
Processing chunk 19, batch 5
Processing chunk 19, batch 6


In [16]:
# Manually release unused cached GPU memory held by PyTorch (the same call the
# processing loop makes after each chunk).
# NOTE(review): presumably run by hand after interrupting the loop -- confirm.
torch.cuda.empty_cache()