In [1]:
import os
import sys
import pandas as pd
import numpy as np
from tqdm import tqdm
import re
import gc
import unicodedata
import string

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from torch.utils.data import DataLoader, TensorDataset

sys.path.append("..")
from src.utils import get_data

In [None]:
# Set device to be CUDA 
os.environ["CUDA_VISIBLE_DEVICES"] = "2,3"  # Adjust based on your available GPUs
print("Using CUDA device:", os.environ["CUDA_VISIBLE_DEVICES"])

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

os.environ["HF_TOKEN"] = ""

Using CUDA device: 0


In [None]:
LANG = "italian"
TASK = "dodiom"
SPLIT = "train"  # "train" or "test"

NUM_SAMPLES = 100  # 0 for all
# MODEL_PATH = "meta-llama/Llama-3.1-8B-Instruct"
MODEL_PATH = "meta-llama/Llama-3.2-3B-Instruct"

OUTPUT_DIR = "outputs"
BSZ = 16
CHECKPOINT_INTERVAL = 20  
LAYER_WISE_PROBING = True  # set to False to use single-layer span vector
NO_GAP_WORDS = True  # If True, will only consider words in the idiom that are not separated by gaps (e.g., "kick the bucket" -> "kick bucket")

In [100]:
num_samples_str = "all" if NUM_SAMPLES <= 0 else str(NUM_SAMPLES)

output_dir = os.path.join(OUTPUT_DIR, LANG, TASK)
model_name = MODEL_PATH.split("/")[-1]
layer_wise = "all_hidden" if LAYER_WISE_PROBING else "last_hidden"
output_file = os.path.join(output_dir, f"{layer_wise}_{SPLIT}_{num_samples_str}_{model_name}.jsonl")
output_file_vectors = os.path.join(output_dir, f"{layer_wise}_{SPLIT}_{num_samples_str}_{model_name}_vectors.pt")

# Make sure the output directory exists
if not os.path.exists(output_dir):
    os.makedirs(output_dir)

# Check if files exist and load existing data
existing_ids = set()
existing_vectors = torch.tensor([])  # Initialize as empty tensor
if os.path.exists(output_file):
    print(f"Found existing output file: {output_file}")
    existing_data = pd.read_json(output_file, lines=True)
    existing_ids = set(existing_data['id'].tolist())
    print(f"Found {len(existing_ids)} existing processed samples")
    
    if os.path.exists(output_file_vectors):
        print(f"Found existing vectors file: {output_file_vectors}")
        existing_vectors = torch.load(output_file_vectors)
        print(f"Loaded {existing_vectors.shape[0]} existing vectors")
    else:
        print("Warning: JSONL file exists but vectors file is missing. Starting fresh.")
        existing_ids = set()
        # Create an empty tensor for vectors with 3 dimensions
        existing_vectors = None


In [101]:
# Get data 
data = get_data(lang=LANG, task=TASK)[SPLIT]


✅ Loaded train for task='dodiom', lang='italian': (7033, 10)
✅ Loaded test for task='dodiom', lang='italian': (500, 10)
✅ Loaded validation for task='dodiom', lang='italian': empty


In [102]:
def extract_japanese_idiom_tokens(row):
    tokens = row["tokens"]
    pie = row["pie"]
    sentence = row["sentence"]


    # Find idiom char span
    match = re.search(re.escape(pie), sentence)
    if not match:
        print(f"Idiom '{pie}' not found in sentence: {sentence}")
        return []
    idiom_start = match.start()
    idiom_end = match.end()

    # Compute character span for each token
    token_char_spans = []
    cursor = 0
    for token in tokens:
        start = cursor
        end = cursor + len(token)
        token_char_spans.append((start, end))
        cursor = end

    # Select tokens whose span overlaps the idiom span
    selected_tokens = [
        tokens[i]
        for i, (start, end) in enumerate(token_char_spans)
        if not (end <= idiom_start or start >= idiom_end)
    ]

    # Select tokens that are in the idiom span
    # selected_tokens = [
    #     token for token in selected_tokens if token in pie
    # ]

    return selected_tokens


In [103]:
if TASK == "magpie":
    # # Remove idiom_tokens
        data = data.drop(columns=["idiom_tokens"], errors='ignore')
        # New idiom_tokens is created from pie_tokens
        data = data.rename(columns={"pie_tokens": "idiom_tokens"})
        data["idiomatic"] = data["label"].apply(lambda x: True if x == "idiomatic" else False)
elif TASK == "dodiom":
    # Add fake id column
    data["id"] = data.index
    # Add split column
    data["split"] = SPLIT
    # Rename columns
    data = data.rename(columns={"idiom": "idiom_base", "idiom_words": "idiom_tokens"})

    # Add pie column
    data["pie"] = data["true_idioms"].apply(lambda x: x[0])

    # Add idiomatic column
    data["idiomatic"] = data["category"].apply(lambda x: True if x == "idiom" else False)

elif TASK == "open_mwe":
    # Add fake id column
    data["id"] = data.index
    # Add split column
    data["split"] = SPLIT
    # Create new column idiom_tokens
    data['idiom_tokens'] = data.apply(extract_japanese_idiom_tokens, axis=1)
    
else:
    raise ValueError(f"Unsupported task: {TASK}")

# Keep only necessary columns and store full dataset
# full_data = data[["id", "split", "sentence", "idiom_base", "pie", "idiomatic", "idiom_tokens"]].copy()
full_data = data.copy()

print("Data loaded. Number of samples:", len(full_data))

Data loaded. Number of samples: 7033


In [104]:
print("Sample results:")
for i in range(min(5, len(data))):
    print(f"Row {i}:")
    print(f"  Original tokens: {data.iloc[i]['tokens']}...")  # First 10 tokens
    print(f"  Idiom span (pie): {data.iloc[i]['pie']}")
    print(f"  Extracted idiom tokens: {data.iloc[i]['idiom_tokens']}")
    print()

Sample results:
Row 0:
  Original tokens: ['A', 'Clelia', 'piace', 'passeggiare', 'per', 'i', 'campi', 'ed', 'acchiappare', 'quante', 'più', 'farfalle', 'possibile', 'con', 'il', 'suo', 'retino', ',', 'per', 'poi', 'lasciarle', 'libere']...
  Idiom span (pie): acchiappare farfalle
  Extracted idiom tokens: ['acchiappare', 'farfalle']

Row 1:
  Original tokens: ['A', 'me', 'sembra', 'si', 'stare', 'acchiappando', 'solo', 'farfalle', 'seguendo', 'i', 'tuoi', 'consigli']...
  Idiom span (pie): acchiappando farfalle
  Extracted idiom tokens: ['acchiappando', 'farfalle']

Row 2:
  Original tokens: ['Al', 'posto', 'di', 'acchiappare', 'farfalle', ',', 'pensa', 'a', 'studiare', '!']...
  Idiom span (pie): acchiappare farfalle
  Extracted idiom tokens: ['acchiappare', 'farfalle']

Row 3:
  Original tokens: ['Basta', 'acchiappare', 'le', 'farfalle']...
  Idiom span (pie): acchiappare farfalle
  Extracted idiom tokens: ['acchiappare', 'farfalle']

Row 4:
  Original tokens: ['Chi', 'vuole', 'gioc

In [105]:
if existing_ids:
    original_size = len(full_data)
    # Filter out already processed samples
    data = full_data[~full_data['id'].isin(existing_ids)]
    print(f"Filtered out {original_size - len(data)} already processed samples")
    if NUM_SAMPLES > 0:
        remaining2process = NUM_SAMPLES - len(existing_ids)
        if remaining2process < 0:
            print(f"Warning: More samples already processed than requested ({len(existing_ids)} vs {NUM_SAMPLES}).")
            remaining2process = 0
    else:
        remaining2process = len(data) - len(existing_ids)

    print(f"Requested total samples to process: {NUM_SAMPLES if NUM_SAMPLES > 0 else original_size}")
    print(f"Remaining samples to process: {remaining2process}")
    
    if len(data) == 0:
        print("All samples already processed!")
        # return
else:
    data = full_data.copy()
    remaining2process = NUM_SAMPLES if NUM_SAMPLES > 0 else len(data)


# Debug, use a small subset
if NUM_SAMPLES > 0:
    print("Data size before cutting:", len(data))
    indices = np.random.choice(data.index, size=min(remaining2process, len(data)), replace=False)
    data = data.loc[indices].reset_index(drop=True)
    print("Data size after filtering:", len(data))


Data size before cutting: 7033
Data size after filtering: 100


In [106]:
data.head(2)

Unnamed: 0,idiom_base,sentence,category,idiom_indices,idiom_tokens,tokens,true_idioms,tags,tag_ids,idiom_full_span,id,split,pie,idiomatic
0,Prendere con le pinze,Prendo quella cosa con le pinze per una maggio...,nonidiom,"[0, 3, 4, 5]","[Prendo, con, le, pinze]","[Prendo, quella, cosa, con, le, pinze, per, un...",[Prendo con le pinze],"[O, O, O, O, O, O, O, O, O, O]","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]",Prendo quella cosa con,4610,train,Prendo con le pinze,False
1,Brancolare nel buio,"Non brancolare così nel buio, chiedi aiuto ai ...",idiom,"[1, 3, 4]","[brancolare, nel, buio]","[Non, brancolare, così, nel, buio, ,, chiedi, ...",[brancolare nel buio],"[O, B-IDIOM, I-IDIOM, I-IDIOM, O, O, O, O, O, ...","[0, 1, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]",brancolare così nel,1976,train,brancolare nel buio,True


In [107]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, torch_dtype=torch.bfloat16, device_map="auto")
model.eval()  # Set the model to evaluation mode

# Check on which device the model is loaded
device = next(model.parameters()).device
print(f"Model is loaded on device: {device}")

Model is loaded on device: cuda:0


In [108]:
# Tokenize the sentences
tokenized = tokenizer(
    data["sentence"].tolist(),
    return_offsets_mapping=True,
    return_tensors="pt",
    padding=True,
    truncation=True
)

print("Tokenization complete. Number of sentences tokenized:", len(tokenized["input_ids"]))

# Save separately for convenience
offset_mapping = tokenized["offset_mapping"]


Tokenization complete. Number of sentences tokenized: 100


In [109]:
# Step 1: Tokenization (store offset_mapping separately)
tokenized_inputs = {k: v for k, v in tokenized.items() if k != "offset_mapping"}
keys = list(tokenized_inputs.keys())

# Step 2: Batch processing using DataLoader
dataset = TensorDataset(*[tokenized_inputs[k] for k in keys])

dataloader = DataLoader(dataset, batch_size=BSZ)

In [110]:
# Step 3: Initialize mean_vectors and processed_data
mean_vectors = existing_vectors
processed_count = len(existing_ids) if existing_ids else 0

# Track all processed data for checkpointing
if existing_ids:
    # Load existing processed data for checkpointing
    existing_data = pd.read_json(output_file, lines=True)
    processed_data = existing_data.copy()
else:
    processed_data = pd.DataFrame()

In [111]:
def normalize(text):
    """Normalize without over-aggressive replacements. Lowercase and strip control chars."""
    text = unicodedata.normalize("NFKC", text)
    text = text.replace("’", "").replace("‘", "").replace("“", '"').replace("”", '"')
    text = text.replace("`", "'").replace("'", "")  # optionally keep apostrophes
    text = ''.join(c for c in text if not unicodedata.category(c).startswith('C'))
    return text.strip().lower()



def group_subword_matches(decoded_tokens, target_tokens):
    selected_indices = []
    i = 0
    while i < len(decoded_tokens):
        match_found = False
        for j in range(len(decoded_tokens), i, -1):
            candidate_tokens = decoded_tokens[i:j]
            candidate_text = ''.join([t[0] for t in candidate_tokens])
            for target in target_tokens:
                if candidate_text == target:
                    selected_indices.extend([t[1] for t in candidate_tokens])
                    target_tokens.remove(target)
                    i = j - 1
                    match_found = True
                    break
            if match_found:
                break
        i += 1
    return selected_indices




def clean_bad_unicode(word):
    # Fix cases like "motionsâģļ" => strip junk and keep core
    return normalize(re.sub(r"[^a-zA-Z0-9]+$", "", word))

def match_japanese_span_tokens(decoded_tokens: list[tuple[str, int]], span_words: list[str]) -> list[int]:
    """
    Match Japanese span by comparing the full target span string to concatenated decoded tokens.
    
    Parameters:
    - decoded_tokens: list of (decoded_token, token_index)
    - span_words: list of original words from the span (may include kana/kanji sequences)
    
    Returns:
    - List of token indices that correspond to the span
    """
    target_string = normalize(''.join(span_words))
    
    for start in range(len(decoded_tokens)):
        for end in range(start + 1, len(decoded_tokens) + 1):
            candidate = ''.join(tok for tok, _ in decoded_tokens[start:end])
            if normalize(candidate) == target_string:
                return [idx for _, idx in decoded_tokens[start:end]]
    
    return []


In [112]:
def extract_span_vector(
    lang,
    tokenizer,
    row_idx,
    sentence,
    tokens,
    offsets,
    hidden_states,
    span_no_gap_words,
    span_words,
    use_all_layers=False,
) -> torch.Tensor:
    
    # Fallback: span is missing
    if not span_no_gap_words:
        print("No span words provided.")
        dim = hidden_states[0].shape[-1]
        return torch.zeros((len(hidden_states), dim)) if use_all_layers else torch.zeros((1, dim))
    
    token_list = tokenizer.convert_ids_to_tokens(tokens[row_idx])
    offset_list = offsets[row_idx]

    corrected_offsets = []
    for token, (start, end) in zip(token_list, offset_list):
        if token is None:
            corrected_offsets.append((0, 0))
            continue
        if token.startswith("Ġ") or token.startswith("▁"):
            corrected_offsets.append((start + 1, end))
        else:
            corrected_offsets.append((start, end))

    # Match span
    match = re.search(re.escape(span_words), sentence)
    if not match:
        dim = hidden_states[0].shape[-1]
        return torch.zeros((len(hidden_states), dim)) if use_all_layers else torch.zeros((1, dim))

    char_span = (match.start(), match.end())

    span_token_indices = [
        i for i, (start, end) in enumerate(corrected_offsets)
        if start >= char_span[0] and end <= char_span[1]
    ]

    if not span_token_indices:
        dim = hidden_states[0].shape[-1]
        return torch.zeros((len(hidden_states), dim)) if use_all_layers else torch.zeros((1, dim))

    target_tokens = [normalize(t) for t in span_no_gap_words]


    target_tokens = [tokenizer.decode(tokenizer.encode(t, add_special_tokens=False)).strip().lower() for t in target_tokens]
    
    decoded_tokens = [
        (normalize(tokenizer.convert_tokens_to_string([token_list[i]])), i)
        for i in span_token_indices
    ]


    # If Japanese
    if lang == "japanese":
        selected_indices = match_japanese_span_tokens(decoded_tokens, span_no_gap_words)
        chosen_tokens = [tok for tok, idx in decoded_tokens if idx in selected_indices]
    else:
        # Match subword groups
        selected_indices = group_subword_matches(decoded_tokens, target_tokens.copy())

    chosen_tokens = [tokenizer.convert_tokens_to_string([token_list[i]]).strip().lower() for i in selected_indices]


    if not selected_indices:
        dim = hidden_states[0].shape[-1]
        return torch.zeros((len(hidden_states), dim)) if use_all_layers else torch.zeros((1, dim))

    # Remove punctuation-only tokens from end
    while selected_indices:
        idx = selected_indices[-1]
        tok = tokenizer.convert_tokens_to_string([token_list[idx]]).strip()
        if re.fullmatch(r"\W+", tok):  # only non-word characters
            selected_indices.pop()
            chosen_tokens.pop()
        else:
            break

    # --- Debugging ---
    # Normalize and compare the matched tokens against the expected ones
    matched_text = ''.join(chosen_tokens).replace(" ", "")
    expected_text = ''.join(target_tokens).replace(" ", "")
    if matched_text != expected_text:
        print(f"⚠️ Mismatch in span words for row {row_idx}:")

        print(f"sentence: {sentence}")
        print(f"span_no_gap_words: {span_no_gap_words}")
        print(f"span_token_indices: {span_token_indices}")
        print(f"Target tokens: {target_tokens}")
        print(f"Decoded tokens: {decoded_tokens}")
        print(f"chosen_tokens: {chosen_tokens}")
        print(f"selected indices: {selected_indices}")
        print()

    # --- Extract vector(s) ---
    if use_all_layers:
        layer_means = []
        for layer in hidden_states:
            span_vecs = layer[row_idx, selected_indices, :]
            mean_vec = span_vecs.mean(dim=0)
            layer_means.append(mean_vec.to(dtype=torch.float32).cpu())
        return torch.stack(layer_means, dim=0)
    else:
        span_vecs = hidden_states[0][row_idx, selected_indices, :]
        mean_vec = span_vecs.mean(dim=0)
        return mean_vec.unsqueeze(0).to(dtype=torch.float32).cpu()


In [113]:
model.eval()
with torch.inference_mode():
    for i, batch in enumerate(tqdm(dataloader, desc="Processing batches")):
        batch_dict = {k: v.to(model.device) for k, v in zip(keys, batch)}
        outputs = model(
            **batch_dict,
            output_hidden_states=True,
            output_attentions=False,
            use_cache=False  # disables `past_key_values`
        )
        del batch_dict

        if LAYER_WISE_PROBING:
            hidden_states = outputs.hidden_states  # tuple/list of tensors
        else:
            # Wrap the last hidden state in a list to maintain consistent structure
            hidden_states = [outputs.hidden_states[-1]]

        del outputs

        # Determine the actual batch slice in the dataset
        start_idx = i * BSZ
        # FIX: Access the actual tensor shape, not the list/tuple shape
        batch_size = hidden_states[0].shape[0]  # Get batch size from the first layer's tensor
        end_idx = start_idx + batch_size
        offset_batch = offset_mapping[start_idx:end_idx]
        token_batch = tokenized["input_ids"][start_idx:end_idx]

        for j in range(batch_size):
            row = data.iloc[start_idx + j]

            mean_vec = extract_span_vector(
                lang=LANG,
                tokenizer=tokenizer,
                row_idx=j,
                span_words=row["pie"],
                sentence=row["sentence"],
                tokens=token_batch,
                offsets=offset_batch,
                hidden_states=hidden_states,
                span_no_gap_words=row["idiom_tokens"],
                # no_gap_words=NO_GAP_WORDS,
                use_all_layers=LAYER_WISE_PROBING,
            )
               

Processing batches: 100%|██████████| 7/7 [00:02<00:00,  3.43it/s]
