### Encoder–Decoder Summarizer with Fine-Tuned HeBERT (ft-HeBERT)


In [14]:
import torch
import torch.nn as nn
from transformers import AutoModel, BartForConditionalGeneration, BartTokenizer, BartConfig

# Load pretrained HeBERT (encoder) from the local MLM-fine-tuned checkpoint.
hebert_encoder = AutoModel.from_pretrained("./hebert-mlm-3k-drugs/final")

# Load pretrained BART and keep only its decoder stack and config.
bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
decoder = bart_model.model.decoder
bart_config = bart_model.config

# The BART decoder would consume HeBERT's hidden states, so the widths must agree
# (HeBERT: 768, BART: 768 – already aligned).
# NOTE(review): this decoder embeds ids from BART's own vocabulary, while the
# training cell tokenizes targets with HeBERT's BertTokenizer — confirm a shared
# target vocabulary before pairing the two.
assert hebert_encoder.config.hidden_size == bart_config.d_model, (
    f"HeBERT hidden_size ({hebert_encoder.config.hidden_size}) != "
    f"BART d_model ({bart_config.d_model})"
)


Some weights of BertModel were not initialized from the model checkpoint at ./hebert-mlm-3k-drugs/final and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [25]:
import os
import torch
import torch.nn as nn
from transformers import (
    AutoModel,
    BertTokenizer,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    DataCollatorForSeq2Seq,
)
from datasets import load_dataset
from transformers.modeling_outputs import Seq2SeqModelOutput

# ========== GPU Debug ==========
# Synchronous kernel launches make a device-side assert point at the op that caused it.
# CUDA reads this when it initialises, so it only helps if set before the first CUDA op.
# After a device-side assert the CUDA context stays broken: restart the kernel,
# re-running this cell is not enough.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ========== Load Tokenizer ==========
tokenizer = BertTokenizer.from_pretrained("./hebert-mlm-3k-drugs/final")
vocab_size = len(tokenizer)

# ========== Load Dataset ==========
# Absolute path of the original author's machine; set SUMMARIZATION_CSV to override it elsewhere.
SUMMARIZATION_CSV = os.environ.get(
    "SUMMARIZATION_CSV",
    "/home/liorkob/M.Sc/thesis/data/drugs_3k/gpt/summarization_dataset.csv",
)
dataset = load_dataset("csv", data_files={"train": SUMMARIZATION_CSV})

def preprocess_function(examples, max_source_length=512, max_target_length=64):
    """Tokenize a batch of source facts and target summaries for seq2seq training.

    Uses the notebook-global ``tokenizer``. Meant for ``Dataset.map(..., batched=True)``,
    which passes only ``examples``; the length arguments keep their defaults there.

    Args:
        examples: batch mapping with list columns "extracted_gpt_facts" (source texts)
            and "summary" (target texts).
        max_source_length: pad/truncate length for the encoder inputs.
        max_target_length: pad/truncate length for the labels.

    Returns:
        The tokenizer output for the sources plus a "labels" column in which every
        padding id is replaced by -100, so the loss ignores those positions.
    """
    model_inputs = tokenizer(
        examples["extracted_gpt_facts"],
        max_length=max_source_length,
        padding="max_length",
        truncation=True,
    )
    targets = tokenizer(
        text_target=examples["summary"],
        max_length=max_target_length,
        padding="max_length",
        truncation=True,
    )
    pad_id = tokenizer.pad_token_id
    # -100 is the ignore_index used by the model's cross-entropy loss.
    model_inputs["labels"] = [
        [tok if tok != pad_id else -100 for tok in label_ids]
        for label_ids in targets["input_ids"]
    ]
    return model_inputs

# Tokenize every split in batches with preprocess_function.
tokenized_dataset = dataset.map(function=preprocess_function, batched=True)

# ========== Safety Check ==========
def check_token_range(dataset, vocab_size, field="labels"):
    """Raise ``ValueError`` if any id in ``field`` falls outside ``[0, vocab_size)``.

    Entries equal to -100 (the loss ignore index) are skipped. For the first
    offending example, the raw ids and their tokens (with -100 shown as the
    global tokenizer's pad token) are printed before raising.
    """
    for idx, example in enumerate(dataset):
        ids = example[field]
        for tok in ids:
            # Guard clause: ignored slots and in-range ids are fine.
            if tok == -100 or 0 <= tok < vocab_size:
                continue
            print(f"🚨 Invalid token at example {idx}: token={tok}, vocab_size={vocab_size}")
            print("Tokens:", ids)
            printable_ids = [tokenizer.pad_token_id if x == -100 else x for x in ids]
            print("Decoded:", tokenizer.convert_ids_to_tokens(printable_ids))
            raise ValueError("Invalid token index found!")
# Fail fast if any label id lies outside the tokenizer vocabulary (-100 slots excepted).
check_token_range(dataset=tokenized_dataset["train"], vocab_size=vocab_size, field="labels")

# ========== Model ==========
class BertSeq2Seq(nn.Module):
    """Seq2seq summarizer: a pretrained BERT-style encoder plus a freshly
    initialised ``nn.TransformerDecoder``.

    The decoder has its own token-embedding table, a learned position table and an
    output projection sized to ``vocab_size``. It therefore predicts ids in the same
    tokenizer space the labels are written in.

    Args:
        encoder: module called as ``encoder(input_ids=..., attention_mask=...)`` whose
            result exposes ``.last_hidden_state``. Presumably shaped
            (batch, src_len, hidden_size), as for HF BERT models; TODO confirm for others.
        vocab_size: rows of the decoder embedding table and width of the LM head.
        hidden_size: decoder width; must equal the encoder's hidden size.
        num_layers: number of stacked decoder layers.
        dropout: dropout probability inside each decoder layer.
        nhead: attention heads per decoder layer (must divide ``hidden_size``).
        pad_token_id: id written into ignored (-100) label slots when building decoder
            inputs. ``None`` reads it from the notebook-global ``tokenizer``.
        decoder_start_token_id: id placed at position 0 of every decoder input.
            ``None`` means "use ``pad_token_id``".
        max_target_positions: rows of the learned decoder position table, i.e. the
            longest decoder input that can be processed.
    """

    def __init__(self, encoder, vocab_size, hidden_size, num_layers=2, dropout=0.1,
                 nhead=8, pad_token_id=None, decoder_start_token_id=None,
                 max_target_positions=512):
        super().__init__()
        self.encoder = encoder
        decoder_layer = nn.TransformerDecoderLayer(d_model=hidden_size, nhead=nhead, dropout=dropout)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        # nn.TransformerDecoder adds no positional information itself; without this
        # table the decoder inputs carry token identity but not token order.
        self.position_embedding = nn.Embedding(max_target_positions, hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size)
        if pad_token_id is None:
            # Backward-compatible fallback to the tokenizer defined earlier in the notebook.
            pad_token_id = tokenizer.pad_token_id
        self.pad_token_id = pad_token_id
        self.decoder_start_token_id = (
            pad_token_id if decoder_start_token_id is None else decoder_start_token_id
        )

    def forward(self, input_ids, attention_mask, labels=None, decoder_input_ids=None):
        """Encode the source and, when decoder inputs are available, decode.

        Args:
            input_ids: source token ids, (batch, src_len).
            attention_mask: nonzero for real source tokens, 0 for padding; same shape.
            labels: optional target ids, (batch, tgt_len). -100 marks positions the
                loss ignores. When given without ``decoder_input_ids``, decoder inputs
                are built by ``shift_tokens_right`` (teacher forcing).
            decoder_input_ids: optional explicit decoder inputs (e.g. for step-wise
                generation); allows returning logits without labels.

        Returns:
            ``Seq2SeqLMOutput`` with ``loss`` (only when ``labels`` is given), ``logits``
            of shape (batch, tgt_len, vocab_size) when the decoder ran (else ``None``),
            and ``encoder_last_hidden_state``.

        Raises:
            ValueError: decoder ids outside the embedding table, or a decoder input
                longer than ``max_target_positions``.
        """
        # Seq2SeqModelOutput has no loss/logits fields; the LM variant does.
        from transformers.modeling_outputs import Seq2SeqLMOutput

        # Follow the model's own device instead of a notebook global, so this also
        # works when the model lives elsewhere (e.g. on a data-parallel replica).
        device = self.embedding.weight.device
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state

        if labels is not None:
            labels = labels.to(device)
            if decoder_input_ids is None:
                decoder_input_ids = self.shift_tokens_right(labels)

        loss = None
        logits = None
        if decoder_input_ids is not None:
            logits = self._decode(decoder_input_ids.to(device), encoder_outputs, attention_mask)
            if labels is not None:
                loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
                loss = loss_fct(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))

        return Seq2SeqLMOutput(
            loss=loss,
            logits=logits,
            encoder_last_hidden_state=encoder_outputs,
        )

    def _decode(self, decoder_input_ids, encoder_hidden, attention_mask):
        """Run the decoder stack and LM head; returns (batch, tgt_len, vocab_size) logits."""
        vocab_rows = self.embedding.num_embeddings
        # Out-of-range ids would otherwise surface as an opaque CUDA device-side assert.
        if decoder_input_ids.max() >= vocab_rows or decoder_input_ids.min() < 0:
            print("🛑 INVALID INDEX IN DECODER")
            print("Max index:", decoder_input_ids.max().item())
            print("Embedding size:", vocab_rows)
            raise ValueError("Decoder token index out of embedding bounds")

        tgt_len = decoder_input_ids.size(1)
        max_positions = self.position_embedding.num_embeddings
        if tgt_len > max_positions:
            raise ValueError(
                f"Decoder input length {tgt_len} exceeds max_target_positions={max_positions}"
            )
        positions = torch.arange(tgt_len, device=decoder_input_ids.device)
        token_embeds = self.embedding(decoder_input_ids) + self.position_embedding(positions)

        # The decoder layers use the default batch_first=False: (seq, batch, hidden).
        decoder_embeds = token_embeds.transpose(0, 1)
        memory = encoder_hidden.transpose(0, 1)
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt_len).to(decoder_embeds.device)
        # Keep cross-attention off padded source positions (attention_mask == 0).
        # No tgt key-padding mask: padded targets sit at the end, are hidden from real
        # positions by the causal mask, and the start token may equal the pad id.
        memory_key_padding_mask = attention_mask == 0

        decoder_outputs = self.decoder(
            decoder_embeds,
            memory,
            tgt_mask=tgt_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )
        return self.lm_head(decoder_outputs.transpose(0, 1))

    def shift_tokens_right(self, labels):
        """Build teacher-forcing decoder inputs from labels.

        Prepends ``decoder_start_token_id`` and drops the last label. Ignored (-100)
        slots are first replaced by ``pad_token_id`` so every id is a valid
        embedding index.
        """
        labels = labels.masked_fill(labels == -100, self.pad_token_id)
        shifted = labels.new_full(labels.shape, self.decoder_start_token_id)
        shifted[:, 1:] = labels[:, :-1]
        return shifted

# ========== Initialize Model ==========
encoder = AutoModel.from_pretrained("./hebert-mlm-3k-drugs/final")

# ========== Pre-flight Checks ==========
# An index past the end of an embedding table fails inside a CUDA kernel as an
# opaque "device-side assert". The label check above compares ids only with
# len(tokenizer). The source ids index the encoder's own word and position
# tables, so verify those here and fail with a readable message instead.
encoder_vocab_rows = encoder.get_input_embeddings().num_embeddings
if vocab_size > encoder_vocab_rows:
    raise ValueError(
        f"Tokenizer has {vocab_size} tokens but the encoder embedding table has only "
        f"{encoder_vocab_rows} rows; call encoder.resize_token_embeddings(vocab_size) "
        "or load a matching tokenizer/checkpoint pair."
    )
source_len = len(tokenized_dataset["train"][0]["input_ids"])
max_source_positions = getattr(encoder.config, "max_position_embeddings", None)
if max_source_positions is not None and source_len > max_source_positions:
    raise ValueError(
        f"Source length {source_len} exceeds the encoder's "
        f"max_position_embeddings={max_source_positions}"
    )

model = BertSeq2Seq(encoder, vocab_size=vocab_size, hidden_size=encoder.config.hidden_size).to(device)

# ========== Training Arguments ==========
training_args = Seq2SeqTrainingArguments(
    output_dir="./bert_summarizer",
    per_device_train_batch_size=4,
    num_train_epochs=5,
    save_strategy="epoch",
    logging_strategy="epoch",
    eval_strategy="no",  # no eval split here; older transformers call this `evaluation_strategy`
    report_to="none",
    seed=42
)

# Inputs and labels are already padded to fixed lengths in preprocess_function.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=None)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    data_collator=data_collator
)

# ========== Train ==========
trainer.train()
trainer.save_model("./bert_summarizer/final")
tokenizer.save_pretrained("./bert_summarizer/final")


Map:   0%|          | 0/2987 [00:00<?, ? examples/s]

Some weights of BertModel were not initialized from the model checkpoint at ./hebert-mlm-3k-drugs/final and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
