In [1]:
from datasets import load_dataset
# Load the 'OpenLLM-Ro/ro_arc_challenge' dataset. trust_remote_code=True lets the
# `datasets` library run loader code shipped with the dataset repository, so it
# should only be used with sources you trust.
dataset = load_dataset('OpenLLM-Ro/ro_arc_challenge', trust_remote_code=True)

def preprocess_splits(dataset):
    """Shuffle each split and keep only the four-option (A-D) questions.

    For the 'train', 'validation' and 'test' entries of `dataset`, the split is
    converted to a pandas DataFrame, shuffled deterministically
    (``random_state=1``) and re-indexed from 0. Rows whose ``answer`` is 'E' or
    whose ``option_e`` is not null are then removed, and the ``option_e``
    column is dropped. The index is not reset after filtering, so it may
    contain gaps.

    Args:
        dataset: Mapping from split name to an object exposing ``to_pandas()``.

    Returns:
        dict with keys 'train', 'validation', 'test' (in that order) mapping to
        the filtered DataFrames.
    """

    def _four_option_frame(split_name):
        # Deterministic shuffle, then a fresh 0..n-1 index.
        shuffled = (
            dataset[split_name]
            .to_pandas()
            .sample(frac=1, random_state=1)
            .reset_index(drop=True)
        )
        # Keep rows that neither answer with 'E' nor provide a fifth option.
        has_four_options = (shuffled['answer'] != 'E') & shuffled['option_e'].isnull()
        return shuffled[has_four_options].drop(columns=["option_e"])

    return {
        split_name: _four_option_frame(split_name)
        for split_name in ("train", "validation", "test")
    }

# Build the shuffled, four-option train/validation/test DataFrames used by the later cells.
splits = preprocess_splits(dataset)
# Bare expression: the notebook displays the dict of DataFrames as this cell's output.
splits

{'train':                                                option_d  \
 0     Au fost formați din rămășițele plantelor și an...   
 1     Mai puține poluanții vor fi produși de vehicul...   
 2     patru elevi care lucrează împreună pentru a mu...   
 3                                  Ele creează energie.   
 4     Un organism care depinde de surse similare de ...   
 ...                                                 ...   
 1104  de la greutatea sedimentelor care presează asu...   
 1105                      de la Carul Mare la Carul Mic   
 1106  hidrogen și oxigen reacționând pentru a produc...   
 1107          folosește două mărgele de aceeași mărime.   
 1108  reduces procentul de căldură pierdută în atmos...   
 
                                          id  \
 0        ARC-Challenge/train/Mercury_402062   
 1        ARC-Challenge/train/MDSA_2010_5_35   
 2     ARC-Challenge/train/Mercury_SC_416461   
 3       ARC-Challenge/train/Mercury_7185448   
 4         ARC-Challenge/trai

In [2]:
import torch
from tqdm import tqdm
import unicodedata
from enum import Enum
import os

# Per-model switch read by EmbeddingExtractor: models mapped to True get their
# prompts stripped of diacritics and lower-cased before tokenization.
# Models not listed here default to False (no normalization).
MODEL_NORMALIZATION = {
    "faur-ai/LLMic": True
}

def remove_diacritics(text):
    """Return `text` with all combining marks removed.

    The string is decomposed with Unicode NFD so that accented letters become
    a base character followed by combining marks; every character in the
    'Mn' (nonspacing mark) category is then discarded.
    """
    decomposed = unicodedata.normalize('NFD', text)
    base_chars = [ch for ch in decomposed if unicodedata.category(ch) != 'Mn']
    return ''.join(base_chars)

class EmbeddingExtractor:
    """Turn a text into a single vector taken from a causal LM's last hidden layer.

    The `pooling` argument has the form ``"<strategy>-<method>"``:

    * strategy ``classical``: prompt = instruction + text; pool over the text span.
    * strategy ``echo``: the text is repeated after a second instruction; pool
      over the repeated copy.
    * strategy ``summary``: the model is asked for one keyword and the number of
      tokens in the first generated word decides how many trailing positions
      are pooled.
    * method ``avg`` averages the selected positions, ``last`` takes the final one.

    Models flagged in ``MODEL_NORMALIZATION`` get prompts without diacritics and
    in lower case.
    """

    _SAVE_FORMATS = ("pt", "npy")

    def __init__(self, model, tokenizer, model_name, device=None, pooling="classical-avg"):
        """Move `model` to `device` (CUDA when available by default) and switch it to eval mode."""
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.model = model.to(self.device).eval()
        self.tokenizer = tokenizer
        self.pooling = pooling
        self.model_name = model_name
        self.model_needs_normalization = MODEL_NORMALIZATION.get(self.model_name, False)
        # Configs without `max_position_embeddings` yield None, which makes the
        # tokenizer fall back to its own maximum length instead of raising here.
        self.max_length = getattr(self.model.config, "max_position_embeddings", None)

    def _build_prompt(self, text, strategy):
        """Wrap `text` in the instruction template that belongs to `strategy`."""
        if strategy == "echo":
            prompt = f"Rescrie întrebarea cu patru variante de răspuns:\n\n{text}\n\nÎntrebarea rescrisă:\n\n{text}"
        elif strategy == "summary":
            prompt = f"Scrie întrebarea cu patru variante de răspuns:\n\n{text}\n\nCare este un singur cuvânt cheie relevant pentru întrebare?"
        else:
            prompt = f"Scrie întrebarea cu patru variante de răspuns:\n\n{text}"

        if self.model_needs_normalization:
            return remove_diacritics(prompt).lower()
        return prompt

    def _decoded_prompt(self, inputs):
        """Decode the prompt token ids back to text, normalized like the prompt itself."""
        decoded = self.tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True).strip()
        if self.model_needs_normalization:
            decoded = remove_diacritics(decoded).lower()
        return decoded

    def _text_after(self, full_text, instruction, label):
        """Return what follows `instruction` in `full_text`; raise ValueError if it is absent."""
        idx = full_text.find(instruction)
        if idx == -1:
            raise ValueError(f"Failed to find {label} instruction '{instruction}' in the prompt text.")
        return full_text[idx + len(instruction):].strip()

    def _content_token_count(self, span):
        """Count the tokens of `span` without any special tokens.

        Tokenizers differ in how many special tokens they add (none, BOS only,
        BOS and EOS), so they are switched off rather than subtracting a fixed
        number from the encoded length.
        """
        if not span:
            return 0
        encoded = self.tokenizer(span, truncation=True, add_special_tokens=False)
        return len(encoded["input_ids"])

    @staticmethod
    def _tail(hidden_states, length):
        """Select the last `length` positions, or only the final one when length < 1.

        The guard matters because ``hidden_states[:, -0:, :]`` would select the
        whole sequence and a negative length would select an arbitrary slice.
        """
        if length < 1:
            return hidden_states[:, -1:, :]
        return hidden_states[:, -length:, :]

    def _apply_pooling(self, hidden_states, inputs, text, strategy, pooling_method):
        """Pool the hidden states that belong to the span chosen by `strategy`.

        Args:
            hidden_states: last-layer hidden states of the prompt, indexed as
                (batch, position, feature).
            inputs: tokenizer output for the prompt (`input_ids`, and
                `attention_mask` for the summary strategy).
            text: the raw input text; currently unused, kept for interface
                compatibility.
            strategy: "echo", "summary", or anything else for the classical prompt.
            pooling_method: "avg" or "last".

        Returns:
            The pooled tensor with singleton dimensions squeezed away.

        Raises:
            ValueError: if the expected instruction is missing from the decoded
                prompt, or if `pooling_method` is unknown.
        """
        if strategy == "echo":
            instruction = "intrebarea rescrisa:" if self.model_needs_normalization else "Întrebarea rescrisă:"
            echoed = self._text_after(self._decoded_prompt(inputs), instruction, "echo")
            selected = self._tail(hidden_states, self._content_token_count(echoed))

        elif strategy == "summary":
            generated_ids = self.model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_new_tokens=10,
                do_sample=False,
                pad_token_id=self.tokenizer.eos_token_id
            )
            # Decode only the newly generated ids. Cutting the decoded string at
            # len(prompt text) is fragile because decoding, stripping and
            # normalization can all change the character count.
            prompt_len = inputs["input_ids"].shape[1]
            completion = self.tokenizer.decode(
                generated_ids[0][prompt_len:], skip_special_tokens=True
            ).strip()
            if not completion:
                print("[WARNING] Model did not generate anything after summary instruction.")
            # First generated word that is alphanumeric once surrounding punctuation is removed.
            first_word = next((w for w in completion.split() if w.strip("-.,!?…").isalnum()), "")
            # NOTE(review): `hidden_states` covers only the prompt, so the positions
            # selected here are trailing prompt tokens, not the generated word
            # itself; confirm this is the intended "summary" representation.
            selected = self._tail(hidden_states, self._content_token_count(first_word))

        else:
            instruction = "scrie intrebarea cu patru variante de raspuns:" if self.model_needs_normalization else "Scrie întrebarea cu patru variante de răspuns:"
            after_instruction = self._text_after(self._decoded_prompt(inputs), instruction, "classical")
            selected = self._tail(hidden_states, self._content_token_count(after_instruction))

        if pooling_method == "avg":
            return selected.mean(dim=1).squeeze()
        elif pooling_method == "last":
            return selected[:, -1, :].squeeze()
        else:
            raise ValueError(f"Unknown pooling method: {pooling_method}")

    def extract_single(self, text):
        """Return the embedding of one `text` using the configured pooling."""
        strategy, pooling_method = self.pooling.split("-")
        prompt = self._build_prompt(text, strategy)
        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=self.max_length).to(self.device)

        with torch.no_grad():
            outputs = self.model(**inputs, output_hidden_states=True)
            hidden_states = outputs.hidden_states[-1]
        text = remove_diacritics(text).lower() if self.model_needs_normalization else text
        return self._apply_pooling(hidden_states, inputs, text, strategy, pooling_method)

    @classmethod
    def _check_save_format(cls, save_format):
        """Raise ValueError for formats that would otherwise be ignored silently."""
        if save_format not in cls._SAVE_FORMATS:
            raise ValueError(f"Unknown save format: {save_format}")

    @classmethod
    def _save(cls, stacked, save_path, save_format):
        """Write `stacked` to `save_path` as a torch file ("pt") or a NumPy array ("npy")."""
        cls._check_save_format(save_format)
        parent = os.path.dirname(save_path)
        # A bare file name has no directory part; os.makedirs("") would raise.
        if parent:
            os.makedirs(parent, exist_ok=True)
        if save_format == "pt":
            torch.save(stacked, save_path)
        else:
            import numpy as np
            np.save(save_path, stacked.numpy())

    def extract_batch(self, texts, save_path=None, save_format="pt"):
        """Embed every item of `texts` and stack the results on the CPU.

        Args:
            texts: iterable of input strings (must not be empty, torch.stack
                rejects an empty list).
            save_path: optional destination file; parent directories are created.
            save_format: "pt" (torch.save) or "npy" (numpy.save).

        Returns:
            Tensor with one row per text.

        Raises:
            ValueError: if `save_path` is given with an unknown `save_format`.
        """
        if save_path:
            # Fail before the expensive extraction, not after it.
            self._check_save_format(save_format)

        embeddings = [
            self.extract_single(text).cpu()
            for text in tqdm(texts, desc=f"Extracting ({self.pooling})")
        ]
        stacked = torch.stack(embeddings)

        if save_path:
            self._save(stacked, save_path, save_format)

        return stacked

    def extract_in_chunks(extractor, texts, save_path, save_format="pt", batch_size=64):
        """Embed `texts` chunk by chunk under one progress bar and concatenate the chunks.

        `extractor` is the instance (it plays the role of `self`). When
        `save_path` is truthy the concatenated result is written in
        `save_format`, exactly as `extract_batch` does.

        Returns:
            Tensor with one row per text.
        """
        if save_path:
            extractor._check_save_format(save_format)

        all_embeddings = []
        total = len(texts)
        with tqdm(total=total, desc=f"Extracting ({extractor.pooling})") as pbar:
            for start in range(0, total, batch_size):
                chunk = texts[start:start + batch_size]
                # Embed item by item so only the single outer bar is shown.
                chunk_rows = []
                for text in chunk:
                    chunk_rows.append(extractor.extract_single(text).cpu())
                    pbar.update(1)
                all_embeddings.append(torch.stack(chunk_rows))

        stacked = torch.cat(all_embeddings)

        if save_path:
            extractor._save(stacked, save_path, save_format)

        return stacked

In [3]:
import os
import numpy as np

class EmbeddingExtractionRunner:
    """Run every pooling strategy over every split for one model and save the results.

    Output layout: ``<save_root>/<model_name>/<strategy>/<split>/embeddings.npy``.
    """

    def __init__(self, model, tokenizer, model_name, device=None):
        """Store the model handles; `device` defaults to CUDA when available, else CPU."""
        if not device:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = model
        self.tokenizer = tokenizer
        self.model_name = model_name
        self.device = device

    @staticmethod
    def _question_with_options(row):
        """Format one dataset row as the question followed by options A-D, one per line."""
        return f"{row['instruction']}\nA. {row['option_a']}\nB. {row['option_b']}\nC. {row['option_c']}\nD. {row['option_d']}"

    def run(self, splits, save_root):
        """Extract and save embeddings for all strategies and all `splits`.

        Args:
            splits: mapping from split name to a DataFrame with the columns
                'instruction', 'option_a', 'option_b', 'option_c', 'option_d'.
            save_root: root directory under which the .npy files are written.
        """
        import gc

        for pooling_name in ("classical-avg", "classical-last", "echo-avg", "summary-avg"):
            extractor = EmbeddingExtractor(
                model=self.model,
                tokenizer=self.tokenizer,
                model_name=self.model_name,
                device=self.device,
                pooling=pooling_name
            )

            for split_name, frame in splits.items():
                texts = [self._question_with_options(record) for _, record in frame.iterrows()]

                out_dir = os.path.join(save_root, self.model_name, pooling_name, split_name)
                os.makedirs(out_dir, exist_ok=True)

                vectors = extractor.extract_batch(texts)
                np.save(os.path.join(out_dir, "embeddings.npy"), vectors.cpu().numpy())

                # Drop the reference and release cached memory before the next split.
                del vectors
                torch.cuda.empty_cache()
                gc.collect()

In [None]:
from huggingface_hub import interpreter_login

# Interactive Hugging Face Hub login.
# NOTE(review): presumably needed because some checkpoints loaded below are gated; confirm.
interpreter_login()

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
# Hugging Face model identifiers to process; each one is passed to
# `from_pretrained` in the debug and extraction cells below.
MODELS = [
    "meta-llama/Llama-3.1-8B-Instruct",
    "OpenLLM-Ro/RoLlama3.1-8b-Instruct",
    "ai-forever/mGPT-1.3B-romanian",
    "faur-ai/LLMic"
]

Debug: run a single training example through every model and pooling strategy to make sure everything works.

In [None]:
# Smoke test: push one training question through every model and every pooling
# strategy and print the shape and first values of the resulting embedding.
import gc

strategies = ["classical-avg", "classical-last", "echo-avg", "summary-avg"]
# Use the GPU when present, but still allow the check to run on a CPU-only machine.
debug_device = "cuda" if torch.cuda.is_available() else "cpu"

# The probe question does not depend on the model or the strategy, so build it once.
row = splits["train"].iloc[0]
text = f"{row['instruction']}\nA. {row['option_a']}\nB. {row['option_b']}\nC. {row['option_c']}\nD. {row['option_d']}"

for model_name in MODELS:
    print(f"MODEL: {model_name}")
    # Load each checkpoint once and reuse it for all strategies; reloading it
    # inside the strategy loop repeats the same expensive load four times.
    model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    for strategy in strategies:
        extractor = EmbeddingExtractor(
            model=model,
            tokenizer=tokenizer,
            model_name=model_name,
            device=debug_device,
            pooling=strategy
        )

        embedding = extractor.extract_single(text)

        print(f"Embedding shape: {embedding.shape}")
        print(f"Embedding preview: {embedding[:10]}")

    # Release this model before the next checkpoint is loaded.
    del extractor, model, tokenizer
    gc.collect()
    torch.cuda.empty_cache()
        

Extraction

In [None]:
# Full extraction: for each model, load the weights and tokenizer, then let the
# runner write embeddings for every pooling strategy and split to
# roarc_embeddings/<model_name>/<strategy>/<split>/embeddings.npy.
for model_name in MODELS:
    model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    # No device is passed, so the runner picks CUDA when available and CPU otherwise.
    runner = EmbeddingExtractionRunner(model, tokenizer, model_name)
    runner.run(splits, save_root="roarc_embeddings")

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Extracting (classical-avg): 100% 1108/1108 [01:04<00:00, 17.20it/s]
Extracting (classical-avg): 100% 296/296 [00:17<00:00, 16.63it/s]
Extracting (classical-avg): 100% 1164/1164 [01:09<00:00, 16.80it/s]
Extracting (classical-last): 100% 1108/1108 [01:04<00:00, 17.18it/s]
Extracting (classical-last): 100% 296/296 [00:17<00:00, 16.64it/s]
Extracting (classical-last): 100% 1164/1164 [01:09<00:00, 16.81it/s]
Extracting (echo-avg): 100% 1108/1108 [01:43<00:00, 10.66it/s]
Extracting (echo-avg): 100% 296/296 [00:28<00:00, 10.34it/s]
Extracting (echo-avg): 100% 1164/1164 [01:51<00:00, 10.40it/s]
Extracting (summary-avg): 100% 1108/1108 [04:59<00:00,  3.70it/s]
Extracting (summary-avg): 100% 296/296 [01:21<00:00,  3.64it/s]
Extracting (summary-avg):  90% 1051/1164 [04:46<00:32,  3.48it/s]