In [1]:
import os
import sys

# Absolute location of the repository checkout used by this notebook.
# NOTE(review): machine-specific path; anyone else running the notebook has to edit it.
project_path = "/home/nlp/achimoa/workspace/hebrew_text_retrieval"
src_path = os.path.join(project_path, "src")

# Run from the project root so relative paths resolve against it.
os.chdir(project_path)

# Register the source directory only once, so re-running the cell does not
# keep appending duplicates to sys.path.
if sys.path.count(src_path) == 0:
    sys.path.append(src_path)

# Expose only GPU index 2 to this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

In [2]:
from datasets import Dataset
from transformers import AutoTokenizer, AutoModel, Trainer, TrainingArguments, AutoConfig, PreTrainedModel
import torch
from torch import nn
import torch.nn.functional as F
from src.data.heq.heq_data import HeQDatasetBuilder, HeQTaskName

In [None]:
# Build the HeQ question -> document dataset without task-token decoration.
heq_dataset_builder = HeQDatasetBuilder(
    task=HeQTaskName.QUESTION_DOC,
    decorate_with_task_tokens=False,
)

# Only the train and validation splits are requested; examples with empty
# answers are filtered out by the builder.
heq_dataset = heq_dataset_builder.build_dataset(
    filter_empty_answers=True,
    splits=['train', 'validation'],
)
heq_dataset

url = https://raw.githubusercontent.com/NNLP-IL/Hebrew-Question-Answering-Dataset/refs/heads/main/data/data%20v1.1/train%20v1.1.json


Filter:   0%|          | 0/4462 [00:00<?, ? examples/s]

url = https://raw.githubusercontent.com/NNLP-IL/Hebrew-Question-Answering-Dataset/refs/heads/main/data/data%20v1.1/val%20v1.1.json


Filter:   0%|          | 0/239 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['anchor_text', 'positive_text', 'index', 'paragraph_index', 'question_id', 'question', 'answer', 'context'],
        num_rows: 3198
    })
    validation: Dataset({
        features: ['anchor_text', 'positive_text', 'index', 'paragraph_index', 'question_id', 'question', 'answer', 'context'],
        num_rows: 168
    })
})

In [5]:
# Show the first validation example to inspect the available fields.
heq_dataset['validation'][0]

{'anchor_text': 'מה אכלו הרשב"י ובנו במערה?',
 'positive_text': 'יהודה בן גרים סיפר על דבריו, ובעקבות זאת נדון ר\' שמעון למיתה על ידי השלטון הרומאי, ונאלץ לרדת למחתרת. על פי המסורת, התחבאו רשב"י ובנו רבי אלעזר 12 שנים במערה בפקיעין, וניזונו מעץ חרוב וממעין מים שנבראו להם בדרך נס. כל אותן 12 שנה היו שניהם לומדים תורה כאשר כל גופם מכוסה חול עד צווארם, ורק בזמן התפילה יצאו מהחול והתלבשו. לאחר 12 שנה הגיע אליהו הנביא למערה והודיע לרשב"י כי קיסר רומא מת וגזרותיו בוטלו. אז יצאו רשב"י ואלעזר בנו ממקום מחבואם, אך כשראה רשב"י בעולם אנשים מתבטלים מתלמוד תורה ועוסקים בחרישה ובזריעה נתן בהם עיניו ונשרפו. אז יצאה בת קול מן השמים ואמרה "להחריב עולמי יצאתם? חזרו למערתכם!". חזרו רשב"י ואלעזר בנו למערה לעוד 12 חודשים, שבסופם יצאו מהמערה ופגשו אדם המביא לכבוד שבת שני הדסים, וכך ראו כמה חביבות מצוות על ישראל ונתקררה דעתם.',
 'index': 0,
 'paragraph_index': 0,
 'question_id': '425478ad-1fb3-4a1a-a100-230cc56e2ccf',
 'question': 'מה אכלו הרשב"י ובנו במערה?',
 'answer': 'חרוב',
 'context': 'יהודה בן גרים סי

In [11]:
# Checkpoint names for the query tower and the document tower; both currently
# point at the same checkpoint.
# NOTE(review): the dataset text is Hebrew, and a Hebrew ModernBERT checkpoint is
# loaded in a later sanity-check cell — confirm this base checkpoint is the intended one.
query_model_name = "answerdotai/ModernBERT-base"
doc_model_name = "answerdotai/ModernBERT-base"

In [None]:
# One tokenizer per tower, each loaded from the corresponding checkpoint name.
# `preprocess` below reads these two module-level names.
tokenizer_q = AutoTokenizer.from_pretrained(query_model_name)
tokenizer_d = AutoTokenizer.from_pretrained(doc_model_name)

def preprocess(
        example,
        query='question',
        paragraph='context',
        truncation=True,
        padding="max_length",
        max_length=1024
    ):
    """Tokenize the query and paragraph fields of one example.

    The query text goes through ``tokenizer_q`` and the paragraph text through
    ``tokenizer_d``; both calls receive the same truncation / padding /
    max_length settings.

    Args:
        example: mapping that holds the text under the ``query`` and
            ``paragraph`` keys.
        query: key of the query text in ``example``.
        paragraph: key of the paragraph text in ``example``.
        truncation: forwarded to both tokenizers.
        padding: forwarded to both tokenizers.
        max_length: forwarded to both tokenizers.

    Returns:
        dict with ``q_input_ids`` / ``q_attention_mask`` taken from the query
        encoding and ``d_input_ids`` / ``d_attention_mask`` taken from the
        paragraph encoding.
    """
    # Identical tokenizer settings for both sides.
    tokenizer_kwargs = dict(truncation=truncation, padding=padding, max_length=max_length)

    query_encoding = tokenizer_q(example[query], **tokenizer_kwargs)
    doc_encoding = tokenizer_d(example[paragraph], **tokenizer_kwargs)

    features = {}
    features["q_input_ids"] = query_encoding['input_ids']
    features["q_attention_mask"] = query_encoding['attention_mask']
    features["d_input_ids"] = doc_encoding['input_ids']
    features["d_attention_mask"] = doc_encoding['attention_mask']
    return features

# Tokenize every split. `map` is called without `batched=True`, so `preprocess`
# receives one example at a time; the new token columns are added next to the
# original text columns.
processed = heq_dataset.map(preprocess)


Map:   0%|          | 0/3198 [00:00<?, ? examples/s]

Map:   0%|          | 0/168 [00:00<?, ? examples/s]

In [None]:
def collate_fn(batch):
    """Turn a list of tokenized examples into a dict of batch tensors.

    Args:
        batch: sequence of mappings, each providing ``q_input_ids``,
            ``q_attention_mask``, ``d_input_ids`` and ``d_attention_mask``
            as lists. Any other keys on the examples are ignored.

    Returns:
        dict with the same four keys; each value is built with
        ``torch.tensor`` from the per-example lists, so within one field all
        examples must have the same length.
    """
    field_names = (
        "q_input_ids",
        "q_attention_mask",
        "d_input_ids",
        "d_attention_mask",
    )
    return {
        name: torch.tensor([example[name] for example in batch])
        for name in field_names
    }


In [6]:
from transformers import ModernBertForMaskedLM, AutoModelForMaskedLM, AutoModel, AutoTokenizer  # adjust import path if needed
import torch

# Sanity check: load the locally stored Hebrew ModernBERT checkpoint, run a single
# short input through it and verify that the output contains no NaNs.
# NOTE(review): absolute, machine-specific checkpoint path.
# NOTE(review): the name `model` is rebound to the dual encoder in the training cell.
model_name_or_path = "/home/nlp/achimoa/workspace/ModernBERT/hf/HebrewModernBERT/ModernBERT-Hebrew-base_20250522_1841"
model = AutoModel.from_pretrained(model_name_or_path)
print("Loaded model:", model_name_or_path)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
print("Loaded tokenizer:", model_name_or_path)
# Tokenize one short Hebrew string as a PyTorch batch of size 1.
inputs = tokenizer("בדיקה", return_tensors="pt")
print("Inputs:", inputs)
# Inference only, so gradient tracking is disabled.
with torch.no_grad():
    out = model(**inputs)
    # out[0] is the first field of the model output (last_hidden_state in the captured run).
    print("NaN in output?", torch.isnan(out[0]).any().item())
    print(out)

Loaded model: /home/nlp/achimoa/workspace/ModernBERT/hf/HebrewModernBERT/ModernBERT-Hebrew-base_20250522_1841
Loaded tokenizer: /home/nlp/achimoa/workspace/ModernBERT/hf/HebrewModernBERT/ModernBERT-Hebrew-base_20250522_1841
Inputs: {'input_ids': tensor([[   2, 8273,    3]]), 'attention_mask': tensor([[1, 1, 1]])}
NaN in output? False
BaseModelOutput(last_hidden_state=tensor([[[ 4.7266e-02,  2.5687e-02,  1.0403e-01,  ...,  3.0027e-02,
          -2.2502e-02,  1.5941e-02],
         [-2.2235e-01,  3.2447e-01, -3.3287e-01,  ...,  4.3003e-01,
           8.0599e-02,  2.6916e-01],
         [-8.0712e-02,  4.8851e-02, -1.0924e-02,  ..., -1.1590e-01,
          -4.3202e-05,  4.0384e-02]]]), hidden_states=None, attentions=None)


In [None]:
class InfoNCEDualEncoder(PreTrainedModel):
    """Two-tower (query / document) encoder trained with an in-batch InfoNCE loss.

    Each tower is loaded with ``AutoModel.from_pretrained``. For a batch of
    aligned (query, document) pairs, ``forward`` scores every query against
    every document in the batch and applies cross-entropy with the matching
    document (same batch index) as the target.

    Args:
        config: configuration passed to ``PreTrainedModel`` and used to build
            the query encoder.
        query_model_name: checkpoint name or path of the query encoder.
        doc_model_name: checkpoint name or path of the document encoder. When
            omitted, or equal to ``query_model_name``, the document encoder is
            loaded from the query checkpoint with the same ``config``. When it
            names a different checkpoint, that checkpoint's own configuration
            is used, because the query configuration may not describe it.
        pooling: ``'cls'`` (hidden state of the first token) or ``'mean'``
            (attention-mask weighted average of the token hidden states).

    Raises:
        ValueError: if ``pooling`` is neither ``'cls'`` nor ``'mean'``.
    """

    def __init__(self, config, query_model_name, doc_model_name=None, pooling='cls'):
        # Fail fast: reject an invalid pooling mode before loading any weights.
        if pooling not in ('cls', 'mean'):
            raise ValueError("Unknown pooling type")
        super().__init__(config)
        self.query_encoder = AutoModel.from_pretrained(query_model_name, config=config)
        if doc_model_name and doc_model_name != query_model_name:
            # A different checkpoint must be built from its own configuration;
            # forcing the query config onto it can mismatch its architecture.
            self.doc_encoder = AutoModel.from_pretrained(doc_model_name)
        else:
            self.doc_encoder = AutoModel.from_pretrained(query_model_name, config=config)
        self.pooling = pooling

    def encode(self, encoder, input_ids, attention_mask):
        """Run ``encoder`` and pool its token states into one vector per sequence.

        Args:
            encoder: the tower to run (query or document encoder).
            input_ids: token ids, passed through to the encoder.
            attention_mask: mask passed through to the encoder; for mean
                pooling it also selects which positions are averaged.

        Returns:
            Pooled embeddings, one row per input sequence ([batch, hidden]).

        Raises:
            ValueError: if ``self.pooling`` is not a supported mode.
        """
        output = encoder(input_ids=input_ids, attention_mask=attention_mask)
        hidden = output.last_hidden_state
        # [CLS] pooling or mean pooling
        if self.pooling == 'cls':
            return hidden[:, 0]  # [batch, hidden]
        elif self.pooling == 'mean':
            # Broadcast a [batch, seq, 1] mask in the hidden-state dtype instead of
            # expanding an integer mask to the full hidden size.
            mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
            sum_emb = (hidden * mask).sum(dim=1)
            # Token counts are whole numbers, so clamp with an integer floor of 1.
            # A fully masked row then yields 0 / 1 = 0, and the result no longer
            # depends on how clamp treats a float minimum on an integer tensor.
            token_counts = attention_mask.sum(dim=1, keepdim=True).clamp(min=1)
            return sum_emb / token_counts
        else:
            raise ValueError("Unknown pooling type")

    def forward(
        self,
        q_input_ids,
        q_attention_mask,
        d_input_ids,
        d_attention_mask,
        labels=None  # not used
    ):
        """Compute the in-batch InfoNCE loss for aligned query/document pairs.

        Args:
            q_input_ids, q_attention_mask: tokenized queries.
            d_input_ids, d_attention_mask: tokenized documents; row ``i`` is the
                positive document for query ``i``.
            labels: accepted but ignored; the targets are the batch indices.

        Returns:
            dict with ``"loss"`` (cross-entropy over the similarity matrix) and
            ``"logits"`` (the [batch, batch] query-by-document score matrix).
        """
        # [batch, hidden]
        q_emb = self.encode(self.query_encoder, q_input_ids, q_attention_mask)
        d_emb = self.encode(self.doc_encoder, d_input_ids, d_attention_mask)

        # [batch, batch] similarity matrix (unnormalised dot product)
        sim_matrix = torch.matmul(q_emb, d_emb.T)

        # InfoNCE loss: the positive for query i is document i; every other
        # document in the batch acts as a negative.
        targets = torch.arange(sim_matrix.size(0), device=sim_matrix.device)
        loss = F.cross_entropy(sim_matrix, targets)

        return {"loss": loss, "logits": sim_matrix}


In [22]:
# Build the dual encoder from the query checkpoint's configuration and train it
# on the tokenized HeQ splits with the Hugging Face Trainer.
config = AutoConfig.from_pretrained(query_model_name)
model = InfoNCEDualEncoder(config, query_model_name, doc_model_name, pooling='cls')

# NOTE(review): in the captured run the validation loss at step 100 shows "No log".
# Presumably the Trainer does not request a loss during evaluation because the
# batches carry no "labels" entry — TODO confirm and make evaluation report the loss.
training_args = TrainingArguments(
    output_dir="./outputs/dual_encoder/dual_encoder_infonce_heq",
    num_train_epochs=2,
    per_device_train_batch_size=8,
    learning_rate=2e-5,
    # Keep all dataset columns; collate_fn picks the four token fields it needs.
    remove_unused_columns=False,
    logging_steps=10,
    save_steps=50,
    eval_strategy="steps",
    eval_steps=100,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=processed['train'],
    eval_dataset=processed['validation'],
    data_collator=collate_fn,
)

trainer.train()


Could not estimate the number of tokens of the input, floating-point operations will not be computed


Step,Training Loss,Validation Loss
100,2.153,No log


: 

: 

In [17]:
# Inspect the tokenized dataset's columns.
# NOTE(review): the captured output lists a_input_ids / a_attention_mask, whereas
# `preprocess` in this notebook emits d_input_ids / d_attention_mask — the output
# looks stale; re-run after a kernel restart to confirm.
processed

DatasetDict({
    train: Dataset({
        features: ['anchor_text', 'positive_text', 'index', 'paragraph_index', 'question_id', 'question', 'answer', 'context', 'q_input_ids', 'q_attention_mask', 'a_input_ids', 'a_attention_mask'],
        num_rows: 3198
    })
    validation: Dataset({
        features: ['anchor_text', 'positive_text', 'index', 'paragraph_index', 'question_id', 'question', 'answer', 'context', 'q_input_ids', 'q_attention_mask', 'a_input_ids', 'a_attention_mask'],
        num_rows: 168
    })
})