### train mlm

In [None]:
import pandas as pd
import torch
import random
import numpy as np
from transformers import (
    MT5ForConditionalGeneration, 
    MT5Tokenizer,
    Trainer,
    TrainingArguments,
    DataCollatorForSeq2Seq
)
from datasets import Dataset
from typing import List, Tuple
import os
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

def split_to_windows(text: str, tokenizer, max_length: int = 256, stride: int = 128) -> List[str]:
    """
    Split text into overlapping windows based on token count
    """
    # Tokenize the text
    tokens = tokenizer.tokenize(text)
    
    # If text is shorter than max_length, return as is
    if len(tokens) <= max_length:
        return [text]
    
    windows = []
    start = 0
    
    while start < len(tokens):
        end = min(start + max_length, len(tokens))
        
        # Extract window tokens
        window_tokens = tokens[start:end]
        
        # Convert back to text
        window_text = tokenizer.convert_tokens_to_string(window_tokens)
        windows.append(window_text)
        
        # Move start position by stride
        start += stride
        
        # If we've covered all tokens, break
        if end >= len(tokens):
            break
    
    return windows

class MLMDataCollator:
    def __init__(self, tokenizer, mlm_probability=0.15, max_length=256):
        self.tokenizer = tokenizer
        self.mlm_probability = mlm_probability
        self.max_length = max_length
    
    def __call__(self, examples):
        # Prepare inputs for span corruption (T5/mT5 style)
        input_texts = []
        target_texts = []
        
        for example in examples:
            text = example['text']  
            
            # Create span corruption
            input_text, target_text = self.create_span_corruption(text)
            
            input_texts.append(input_text)
            target_texts.append(target_text)
        
        # Tokenize inputs and targets
        model_inputs = self.tokenizer(
            input_texts, 
            truncation=True, 
            padding=True, 
            max_length=self.max_length,
            return_tensors="pt"
        )
        
        labels = self.tokenizer(
            target_texts,
            truncation=True,
            padding=True,
            max_length=self.max_length,
            return_tensors="pt"
        )
        
        # Replace padding token id with -100 to ignore in loss calculation
        labels["input_ids"][labels["input_ids"] == self.tokenizer.pad_token_id] = -100
        model_inputs["labels"] = labels["input_ids"]
        
        return model_inputs
    
    def create_span_corruption(self, text: str) -> Tuple[str, str]:
        """Create span corruption similar to T5 pre-training"""
        tokens = self.tokenizer.tokenize(text)
        
        if len(tokens) == 0:
            return text, ""
        
        # Calculate number of spans to mask (roughly 15% of tokens in spans)
        total_mask_length = max(1, int(len(tokens) * self.mlm_probability))
        
        # Create spans to mask
        masked_tokens = tokens.copy()
        target_spans = []
        
        current_span_id = 0
        i = 0
        
        while i < len(tokens) and current_span_id * 3 < total_mask_length:  # Average span length ~3
            if random.random() < 0.15:  # Probability of starting a span
                # Determine span length (1-5 tokens)
                span_length = min(random.randint(1, 5), len(tokens) - i)
                
                # Extract original tokens for target
                original_span = tokens[i:i + span_length]
                original_text = self.tokenizer.convert_tokens_to_string(original_span)
                
                # Create target span
                target_span = f"<extra_id_{current_span_id}> {original_text}"
                target_spans.append(target_span)
                
                # Replace with sentinel token in input
                masked_tokens[i:i + span_length] = [f"<extra_id_{current_span_id}>"]
                
                current_span_id += 1
                i += span_length
            else:
                i += 1
        
        # Add final sentinel to target
        if target_spans:
            target_spans.append(f"<extra_id_{current_span_id}>")
        
        # Convert back to text
        input_text = self.tokenizer.convert_tokens_to_string(masked_tokens)
        target_text = " ".join(target_spans) if target_spans else ""
        
        return input_text, target_text

def prepare_dataset(df: pd.DataFrame, tokenizer, text_column: str = 'extracted_gpt_facts') -> Dataset:
    """
    Prepare dataset from DataFrame with windowing
    """
    all_windows = []
    
    print(f"Processing {len(df)} texts...")
    
    for idx, row in df.iterrows():
        text = str(row[text_column])
        
        # Skip empty texts
        if not text or text.strip() == '':
            continue
            
        # Split text into windows
        windows = split_to_windows(text, tokenizer, max_length=256, stride=128)
        all_windows.extend(windows)
        
        if (idx + 1) % 100 == 0:
            print(f"Processed {idx + 1} texts, generated {len(all_windows)} windows")
    
    print(f"Total windows generated: {len(all_windows)}")
    
    # Create dataset
    dataset = Dataset.from_dict({"text": all_windows})
    return dataset

def main():
    # Set random seeds for reproducibility
    random.seed(42)
    np.random.seed(42)
    torch.manual_seed(42)
    
    # Load your data
    print("Loading data...")
    df = pd.read_csv("/home/liorkob/M.Sc/thesis/data/drugs_3k/gpt/processed_verdicts_with_gpt.csv")
    print(f"Loaded {len(df)} rows")
    
    # Initialize model and tokenizer
    model_name = "google/mt5-base"  
    print(f"Loading model and tokenizer: {model_name}")
    
    # tokenizer = MT5Tokenizer.from_pretrained(model_name)
    # model = MT5ForConditionalGeneration.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained("imvladikon/het5-base")
    model = AutoModelForSeq2SeqLM.from_pretrained("imvladikon/het5-base")  
    # Add extra sentinel tokens for span corruption
    sentinel_tokens = [f"<extra_id_{i}>" for i in range(100)]
    num_added = tokenizer.add_tokens(sentinel_tokens)
    print(f"Added {num_added} sentinel tokens")
    
    # Resize model embeddings
    model.resize_token_embeddings(len(tokenizer))
    
    # Prepare dataset with windowing
    dataset = prepare_dataset(df, tokenizer, text_column='extracted_gpt_facts')
    
    # Split dataset (80% train, 20% validation)
    train_size = int(0.8 * len(dataset))
    eval_size = len(dataset) - train_size
    
    train_dataset = dataset.select(range(train_size))
    eval_dataset = dataset.select(range(train_size, train_size + eval_size))
    
    print(f"Train dataset size: {len(train_dataset)}")
    print(f"Eval dataset size: {len(eval_dataset)}")
    
    # Data collator
    data_collator = MLMDataCollator(tokenizer, mlm_probability=0.15, max_length=256)
    
 # Training arguments
    training_args = TrainingArguments(
    output_dir = "./mt5-mlm-trained",
    remove_unused_columns=False,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    fp16=False,
    evaluation_strategy="steps",
    eval_steps=500,
    per_device_eval_batch_size=1,
    gradient_checkpointing=True,
    num_train_epochs=5,
    logging_dir="./logs",
    logging_steps=100,
    disable_tqdm=False,
    log_level='info',
    save_strategy="no")
    
    
    # Initialize trainer
    print("Initializing trainer...")
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
        tokenizer=tokenizer,
    )
    
    
    # Train the model
    print("Starting training...")
    trainer.train()
    
    # Save the final model
    final_output_dir = "./m5-mlm-final"
    print(f"Saving final model to {final_output_dir}")
    trainer.save_model(final_output_dir)
    tokenizer.save_pretrained(final_output_dir)
    
    print("Training completed!")
    
    # Optional: Test the trained model
    print("Testing trained model...")
    test_text = "This is a sample text to test the trained model."
    
    # # Load the saved model for testing
    # # trained_model = MT5ForConditionalGeneration.from_pretrained(final_output_dir)
    # # trained_tokenizer = MT5Tokenizer.from_pretrained(final_output_dir)
    # trained_tokenizer = AutoTokenizer.from_pretrained("imvladikon/het5-base")
    # trained_model = AutoModelForSeq2SeqLM.from_pretrained("imvladikon/het5-base")  
    # # Create a corrupted version for testing
    # collator = MLMDataCollator(trained_tokenizer)
    # input_text, target_text = collator.create_span_corruption(test_text)
    
    # print(f"Original: {test_text}")
    # print(f"Corrupted: {input_text}")
    # print(f"Target: {target_text}")
    
    # # Generate prediction
    # inputs = trained_tokenizer(input_text, return_tensors="pt", max_length=256, truncation=True)
    
    # with torch.no_grad():
    #     outputs = trained_model.generate(
    #         inputs.input_ids,
    #         max_length=256,
    #         num_beams=4,
    #         early_stopping=True,
    #         pad_token_id=trained_tokenizer.pad_token_id
    #     )
    
    # prediction = trained_tokenizer.decode(outputs[0], skip_special_tokens=True)
    # print(f"Prediction: {prediction}")

if __name__ == "__main__":
    main()

In [1]:
import math, random
import pandas as pd
import torch
from torch.utils.data import DataLoader
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from typing import List

# ---------- config ----------
CSV_PATH         = "/home/liorkob/M.Sc/thesis/data/drugs_3k/gpt/processed_verdicts_with_gpt.csv"
BASE_MODEL_DIR   = "imvladikon/het5-base"               # base HET5
FT_MODEL_DIR     = "/home/liorkob/M.Sc/thesis/t5/het5-mlm-final"  # your fine-tuned HET5
TEXT_COL         = "extracted_gpt_facts"
MAX_LEN          = 256
STRIDE           = 128
MLM_PROB         = 0.15
BATCH_SIZE       = 4
SEED             = 42

# ---------- helpers ----------
def set_seeds(seed: int = 42):
    random.seed(seed); torch.manual_seed(seed)

def split_to_windows(text: str, tokenizer, max_length=256, stride=128) -> List[str]:
    toks = tokenizer.tokenize(text)
    if len(toks) <= max_length:
        return [text]
    out, start = [], 0
    while start < len(toks):
        end = min(start + max_length, len(toks))
        out.append(tokenizer.convert_tokens_to_string(toks[start:end]))
        start += stride
        if end >= len(toks): break
    return out

def load_windows(csv_path: str, tokenizer, text_col: str) -> Dataset:
    df = pd.read_csv(csv_path)
    wins = []
    for _, r in df.iterrows():
        txt = str(r.get(text_col, "") or "").strip()
        if not txt: continue
        wins.extend(split_to_windows(txt, tokenizer, MAX_LEN, STRIDE))
    return Dataset.from_dict({"text": wins})

class MLMDataCollator:
    def __init__(self, tokenizer, mlm_probability=0.15, max_length=256):
        self.tok = tokenizer
        self.p = mlm_probability
        self.max_len = max_length

    def __call__(self, examples):
        inputs, targets = [], []
        for ex in examples:
            inp, tgt = self._span_corrupt(ex["text"])
            inputs.append(inp); targets.append(tgt)
        model_inputs = self.tok(inputs, truncation=True, padding=True,
                                max_length=self.max_len, return_tensors="pt")
        labels = self.tok(targets, truncation=True, padding=True,
                          max_length=self.max_len, return_tensors="pt")["input_ids"]
        labels[labels == self.tok.pad_token_id] = -100
        model_inputs["labels"] = labels
        return model_inputs

    def _span_corrupt(self, text: str):
        toks = self.tok.tokenize(text)
        if not toks: return text, ""
        total_mask = max(1, int(len(toks) * self.p))
        masked = toks.copy(); targets = []; sid, i = 0, 0
        while i < len(toks) and sid * 3 < total_mask:
            if random.random() < 0.15:
                span_len = min(random.randint(1, 5), len(toks) - i)
                orig = self.tok.convert_tokens_to_string(toks[i:i+span_len])
                targets.append(f"<extra_id_{sid}> {orig}")
                masked[i:i+span_len] = [f"<extra_id_{sid}>"]
                sid += 1; i += span_len
            else:
                i += 1
        if targets:
            targets.append(f"<extra_id_{sid}>")
        return (self.tok.convert_tokens_to_string(masked),
                " ".join(targets) if targets else "")

def build_frozen_batches(tokenizer, csv_path: str, batch_size=4):
    """
    Create one fixed set of masked batches ONCE using the base tokenizer.
    Reuse these exact tensors for all models for a fair comparison.
    """
    set_seeds(SEED)
    ds = load_windows(csv_path, tokenizer, TEXT_COL)
    collator = MLMDataCollator(tokenizer, mlm_probability=MLM_PROB, max_length=MAX_LEN)
    loader = DataLoader(ds, batch_size=batch_size, shuffle=False, collate_fn=collator)
    frozen = []
    for batch in loader:
        # keep on CPU; we’ll move to device during eval
        frozen.append({k: v.clone() for k, v in batch.items()})
    return frozen

@torch.no_grad()
def eval_perplexity(model_dir: str, tokenizer, frozen_batches) -> float:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = AutoModelForSeq2SeqLM.from_pretrained(model_dir).to(device).eval()

    # ensure sentinel tokens exist (het5 usually has them already)
    if "<extra_id_0>" not in tokenizer.get_vocab():
        tokenizer.add_tokens([f"<extra_id_{i}>" for i in range(100)])
        model.resize_token_embeddings(len(tokenizer))

    total_loss, total_tokens = 0.0, 0
    for batch in frozen_batches:
        batch_dev = {k: v.to(device) for k, v in batch.items()}
        out = model(**batch_dev)
        valid = (batch_dev["labels"] != -100).sum().item()
        total_loss += out.loss.item() * valid
        total_tokens += valid

    mean_loss = total_loss / max(total_tokens, 1)
    ppl = math.exp(mean_loss)
    return ppl, mean_loss, total_tokens

if __name__ == "__main__":
    # Use the BASE tokenizer to freeze masks for both models
    base_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_DIR, use_fast=False)

    # Build frozen masked batches once
    frozen_batches = build_frozen_batches(base_tokenizer, CSV_PATH, batch_size=BATCH_SIZE)

    # Evaluate base and fine-tuned on the SAME masked inputs
    base_ppl, base_loss, base_tokens = eval_perplexity(BASE_MODEL_DIR, base_tokenizer, frozen_batches)
    ft_ppl,   ft_loss,   ft_tokens   = eval_perplexity(FT_MODEL_DIR,   base_tokenizer, frozen_batches)

    print("=== HET5 Perplexity Comparison (same masking, same data) ===")
    print(f"Base mT5    | loss: {base_loss:.4f} | ppl: {base_ppl:.2f} | tokens: {base_tokens}")
    print(f"Fine-tuned   | loss: {ft_loss:.4f}   | ppl: {ft_ppl:.2f}   | tokens: {ft_tokens}")
    print(f"Δloss: {base_loss - ft_loss:.4f} | Δppl: {base_ppl - ft_ppl:.2f}")


  from .autonotebook import tqdm as notebook_tqdm
You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`
The new lm_head weights will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, u

=== HET5 Perplexity Comparison (same masking, same data) ===
Base mT5    | loss: 4.4979 | ppl: 89.83 | tokens: 597303
Fine-tuned   | loss: 9.0837   | ppl: 8810.11   | tokens: 597303
Δloss: -4.5858 | Δppl: -8720.28


In [4]:
import math, random
import pandas as pd
import torch
from torch.utils.data import DataLoader
from datasets import Dataset, load_from_disk
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# ---------------- helpers ----------------
def split_to_windows(text, tokenizer, max_length=256, stride=128):
    toks = tokenizer.tokenize(text)
    if len(toks) <= max_length: return [text]
    out, start = [], 0
    while start < len(toks):
        end = min(start + max_length, len(toks))
        out.append(tokenizer.convert_tokens_to_string(toks[start:end]))
        start += stride
        if end >= len(toks): break
    return out

class MLMDataCollator:
    def __init__(self, tokenizer, mlm_probability=0.15, max_length=256):
        self.tok = tokenizer
        self.p = mlm_probability
        self.max_len = max_length

    def __call__(self, examples):
        inputs, targets = [], []
        for ex in examples:
            inp, tgt = self._span_corrupt(ex["text"])
            inputs.append(inp); targets.append(tgt)
        model_inputs = self.tok(
            inputs, truncation=True, padding="max_length", max_length=self.max_len, return_tensors="pt"
        )
        labels = self.tok(
            targets, truncation=True, padding="max_length", max_length=self.max_len, return_tensors="pt"
        )["input_ids"]
        labels[labels == self.tok.pad_token_id] = -100
        model_inputs["labels"] = labels
        return model_inputs

    def _span_corrupt(self, text):
        toks = self.tok.tokenize(text)
        if not toks: return text, ""
        total_mask = max(1, int(len(toks) * self.p))
        masked = toks.copy(); targets = []; sid, i = 0, 0
        while i < len(toks) and sid * 3 < total_mask:
            if random.random() < 0.15:
                span_len = min(random.randint(1,5), len(toks)-i)
                orig = self.tok.convert_tokens_to_string(toks[i:i+span_len])
                targets.append(f"<extra_id_{sid}> {orig}")
                masked[i:i+span_len] = [f"<extra_id_{sid}>"]
                sid += 1; i += span_len
            else:
                i += 1
        if targets: targets.append(f"<extra_id_{sid}>")
        return self.tok.convert_tokens_to_string(masked), (" ".join(targets) if targets else "")

def prepare_dataset(csv_path, tokenizer, text_col="extracted_gpt_facts", max_length=256, stride=128):
    df = pd.read_csv(csv_path)
    windows = []
    for _, r in df.iterrows():
        txt = str(r.get(text_col, "") or "").strip()
        if not txt: continue
        windows.extend(split_to_windows(txt, tokenizer, max_length, stride))
    return Dataset.from_dict({"text": windows})

# ------------- create masked eval set (one-time) -------------
def create_masked_eval(csv_path, tokenizer, out_dir="masked_eval", seed=42, batch_size=32):
    random.seed(seed)
    dataset = prepare_dataset(csv_path, tokenizer)
    collator = MLMDataCollator(tokenizer, mlm_probability=0.15, max_length=256)

    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=collator)
    all_input_ids, all_attn, all_labels = [], [], []
    for batch in loader:
        all_input_ids.append(batch["input_ids"])
        all_attn.append(batch["attention_mask"])
        all_labels.append(batch["labels"])

    all_input_ids = torch.cat(all_input_ids, dim=0)
    all_attn     = torch.cat(all_attn,     dim=0)
    all_labels   = torch.cat(all_labels,   dim=0)

    ds_masked = Dataset.from_dict({
        "input_ids": all_input_ids.tolist(),
        "attention_mask": all_attn.tolist(),
        "labels": all_labels.tolist()
    })
    ds_masked.save_to_disk(out_dir)
    print(f"Masked eval dataset saved to {out_dir} with {len(ds_masked)} samples.")

# ------------- perplexity evaluation -------------
@torch.no_grad()
def compute_perplexity_from_masked(model_dir, masked_dir, batch_size=4):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tok = AutoTokenizer.from_pretrained(model_dir, use_fast=False)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_dir).to(device).eval()

    if "<extra_id_0>" not in tok.get_vocab():
        tok.add_tokens([f"<extra_id_{i}>" for i in range(100)])
        model.resize_token_embeddings(len(tok))

    ds = load_from_disk(masked_dir)

    def collate(rows):
        return {
            "input_ids": torch.tensor([r["input_ids"] for r in rows], dtype=torch.long),
            "attention_mask": torch.tensor([r["attention_mask"] for r in rows], dtype=torch.long),
            "labels": torch.tensor([r["labels"] for r in rows], dtype=torch.long),
        }

    loader = DataLoader(ds, batch_size=batch_size, shuffle=False, collate_fn=collate)

    total_loss, total_tokens = 0.0, 0
    for batch in loader:
        batch = {k: v.to(device) for k, v in batch.items()}
        out = model(**batch)
        valid = (batch["labels"] != -100).sum().item()
        total_loss += out.loss.item() * valid
        total_tokens += valid

    mean_loss = total_loss / max(total_tokens, 1)
    ppl = math.exp(mean_loss)
    print(f"{model_dir} | loss: {mean_loss:.4f} | ppl: {ppl:.2f} | tokens: {total_tokens}")
    return ppl

if __name__ == "__main__":
    CSV_PATH   = "/home/liorkob/M.Sc/thesis/data/drugs_3k/gpt/processed_verdicts_with_gpt.csv"
    MASKED_DIR = "masked_eval"

    # 1) build masked eval once (padding fixed to max_length)
    tok_mask = AutoTokenizer.from_pretrained("imvladikon/het5-base", use_fast=False)
    create_masked_eval(CSV_PATH, tok_mask, out_dir=MASKED_DIR, seed=42, batch_size=32)

    # 2) evaluate models
    compute_perplexity_from_masked("imvladikon/het5-base", MASKED_DIR, batch_size=8)
    compute_perplexity_from_masked("/home/liorkob/M.Sc/thesis/t5/het5-mlm-final", MASKED_DIR, batch_size=8)



Saving the dataset (1/1 shards): 100%|██████████| 10975/10975 [00:00<00:00, 22620.32 examples/s]


Masked eval dataset saved to masked_eval with 10975 samples.
imvladikon/het5-base | loss: 4.4979 | ppl: 89.83 | tokens: 597303
/home/liorkob/M.Sc/thesis/t5/het5-mlm-final | loss: 9.0837 | ppl: 8810.11 | tokens: 597303
