In [1]:
import h5py
import torch
import pandas as pd
import numpy as np
from transformers import AutoTokenizer, AutoModel
from pathlib import Path
from tqdm import tqdm
import os

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Data files live in a `data` directory two levels above the current working
# directory. The paths are kept as plain strings.
_DATA_DIR = Path.cwd().parent.parent / 'data'
H5_PATH = str(_DATA_DIR / 'stackexchange_embeddings_tokenized.h5')
CSV_PATH = str(_DATA_DIR / 'stackexchange_dataset.csv')

# Hugging Face model identifier used for both the tokenizer and the encoder.
MODEL_NAME = 'Qwen/Qwen3-Embedding-8B'

# `max_length` values handed to the tokenizer for question bodies and titles.
MAX_LEN_BODY, MAX_LEN_TITLE = 32, 4

In [3]:
def manual_mean_pooling(model_output, attention_mask):
    """Average the token embeddings of each sequence, ignoring masked positions.

    Args:
        model_output: Object exposing a ``last_hidden_state`` tensor.
        attention_mask: Tensor that is broadcast over the trailing dimension
            of ``last_hidden_state``; positions with value 0 do not contribute.

    Returns:
        ``last_hidden_state`` summed over dimension 1 with the mask applied,
        divided by the per-sequence mask total (floored at 1e-9 so that a
        fully masked sequence yields zeros instead of dividing by zero).
    """
    hidden = model_output.last_hidden_state
    # Broadcast the mask across the embedding axis; the float cast means the
    # product below is computed in at least float32.
    weights = attention_mask.unsqueeze(-1).expand(hidden.size()).float()
    weighted_total = (hidden * weights).sum(dim=1)
    weight_total = weights.sum(dim=1).clamp(min=1e-9)
    return weighted_total / weight_total

In [None]:
def _find_corrupted_rows(body_dset, title_dset, chunk_size):
    """Return the row indices whose body or title entry is not fully finite.

    The datasets are read ``chunk_size`` rows at a time so that only one
    slice of each is held in memory at once.
    """
    bad_indices = []
    num_samples = body_dset.shape[0]

    for start in tqdm(range(0, num_samples, chunk_size), desc="Scanning"):
        end = min(start + chunk_size, num_samples)

        b_chunk = body_dset[start:end]
        t_chunk = title_dset[start:end]

        # One flag per row: True when any element of that row is NaN or +/-Inf.
        bad_body = ~np.isfinite(b_chunk).all(axis=(1, 2))
        bad_title = ~np.isfinite(t_chunk).all(axis=1)

        local_bad = np.where(bad_body | bad_title)[0]
        bad_indices.extend((local_bad + start).tolist())

    return bad_indices


def _embed_text(text, tokenizer, model, max_length, pool):
    """Re-embed one string and return a finite float16 array, or None.

    ``None`` is returned when ``text`` is not a string (for example the NaN
    pandas uses for an empty CSV cell) or when the result still contains
    non-finite values. ``torch.clamp`` bounds +/-Inf but lets NaN through, so
    the finiteness check is what actually guards the write-back.
    """
    if not isinstance(text, str):
        return None

    inputs = tokenizer(
        [text], return_tensors="pt", padding="max_length",
        truncation=True, max_length=max_length
    ).to(model.device)

    with torch.no_grad():
        output = model(**inputs)
        if pool:
            emb = manual_mean_pooling(output, inputs.attention_mask)
        else:
            emb = output.last_hidden_state

        # Clamp to a safe float16 range before the down-cast.
        emb = torch.clamp(emb, min=-65000, max=65000)
        emb_np = emb.half().cpu().numpy()

    if not np.isfinite(emb_np).all():
        return None
    return emb_np[0]


def repair(chunk_size=1000):
    """Find non-finite embeddings in the HDF5 file and re-compute them in place.

    The ``body_seq`` and ``title_emb`` datasets of ``H5_PATH`` are scanned for
    rows containing NaN/Inf. Each such row is re-embedded from the
    ``question_text`` and ``title`` columns of ``CSV_PATH`` with ``MODEL_NAME``
    loaded in float32, then written back as float16. Rows whose source text is
    missing, or whose fresh embedding is still non-finite, are left untouched
    and reported instead of being silently rewritten.

    Args:
        chunk_size: Number of rows read per slice while scanning.

    Raises:
        IndexError: If a corrupted row index has no matching CSV row. This is
            checked before the model is loaded and before anything is written.
    """
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

    df = pd.read_csv(CSV_PATH, index_col='question_id')
    df = df[~df.index.duplicated()]

    unrepaired = []

    with h5py.File(H5_PATH, 'r+') as f:
        body_dset = f['body_seq']
        title_dset = f['title_emb']

        bad_indices = _find_corrupted_rows(body_dset, title_dset, chunk_size)

        print(f"Found {len(bad_indices)} corrupted samples.")

        if len(bad_indices) == 0:
            print("No repairs needed!")
            return

        # NOTE(review): rows are matched by position, which assumes the HDF5
        # row order equals the de-duplicated CSV row order - confirm against
        # the script that produced the file.
        if max(bad_indices) >= len(df):
            raise IndexError(
                f"Corrupted row {max(bad_indices)} has no matching CSV row "
                f"(CSV has {len(df)} rows after de-duplication)."
            )

        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # `torch_dtype` triggers a deprecation warning on recent transformers
        # releases (which prefer `dtype`); it is kept so that older releases
        # continue to work.
        model = AutoModel.from_pretrained(
            MODEL_NAME,
            trust_remote_code=True,
            device_map="auto",
            torch_dtype=torch.float32
        )
        model.eval()

        for idx in tqdm(bad_indices, desc="Fixing"):
            row = df.iloc[idx]

            # Body keeps the per-token hidden states; title is mean-pooled.
            body_np = _embed_text(
                row['question_text'], tokenizer, model, MAX_LEN_BODY, pool=False
            )
            title_np = _embed_text(
                row['title'], tokenizer, model, MAX_LEN_TITLE, pool=True
            )

            if body_np is not None:
                body_dset[idx] = body_np
            if title_np is not None:
                title_dset[idx] = title_np

            if body_np is None or title_np is None:
                unrepaired.append(idx)

    if unrepaired:
        print(
            f"Repair incomplete: {len(unrepaired)} samples could not be "
            f"repaired (missing text or non-finite model output): {unrepaired}"
        )
    else:
        print("Repair complete! No more NaNs.")

In [7]:
# Scan the HDF5 embeddings for non-finite rows and re-embed them in place.
# This opens the file in 'r+' mode and overwrites the affected rows; the
# recorded run below took about 12.5 minutes for 147 samples.
repair()

Memory cleaned. Starting repair...
Loading CSV...
Scanning H5 for corruption...


Scanning: 100%|██████████| 100/100 [00:24<00:00,  4.00it/s]


Found 147 corrupted samples.
Loading Model in float32 (with CPU offload)...


`torch_dtype` is deprecated! Use `dtype` instead!
Loading checkpoint shards: 100%|██████████| 4/4 [00:07<00:00,  1.87s/it]
Some parameters are on the meta device because they were offloaded to the cpu and disk.


Repairing samples...


Fixing: 100%|██████████| 147/147 [12:28<00:00,  5.09s/it]


Repair complete! No more NaNs.
