In [None]:
import pandas as pd
from datasets import Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForMaskedLM,
    DataCollatorForLanguageModeling,
    TrainingArguments,
    Trainer,
)
from transformers.models.bert.modeling_bert import BertOnlyMLMHead

from transformers import (
    BertConfig,
    BertModel,
    AutoConfig,
    PreTrainedModel,
    AutoModel
)
import os
import numpy as np
import re
import pickle
import wandb
import pytorch_lightning as pl
import torch.nn as nn
from pytorch_lightning.loggers import WandbLogger
from torch.utils.data import DataLoader
import torch

from sklearn.model_selection import train_test_split
from transformers import BertForMaskedLM
from peft import LoraConfig, get_peft_model

In [None]:
# Log in to Weights & Biases; the WandbLogger created in the training cell reports to it.
wandb.login()

# Load and prepare data

In [None]:
# Input CSVs: conversation-level rows (carry `engagement_id` and `gsr`) and
# message-level rows (carry `engagement_id` and `text`).
conversations_path = "arabic_conversations.csv"
messages_path = "arabic_messages.csv"
# Lexicon CSV; read below and filtered by its `Category` column.
ARABIC_LEXICON_PATH = "new_arabic_lexicon_17_07.csv"
# NOTE(review): not referenced in the visible cells -- the regression head is
# sized with len(needed_categories) instead. Confirm whether this is still needed.
TOTAL_LEXICON_CATEGORIES = 47
# Lexicon column holding the Arabic phrase that is searched for in the text.
LEXICON_ARABIC_PHRASE_COLUMN = "Arabic Phrase (Linor Translation / Approval)"

# Hugging Face checkpoint used for both the tokenizer and the masked-LM weights.
MODEL_NAME = "aubmindlab/bert-large-arabertv02"


# Fraction of labeled conversations held out by train_test_split.
test_size = 0.3
# random_state for the train/test split.
seed = 358

In [None]:
# Load the phrase lexicon and keep only the categories used as regression targets.
lexicon_df = pd.read_csv(ARABIC_LEXICON_PATH)
needed_categories = [
    "Past suicidal history",
    "Family suicide history",
    "Suicidal ideation",
    "Hopelessness",
    "Deliberate self harm",
    "Perceived burdensomeness",
]
# .copy(): later cells assign columns on this frame; without an explicit copy
# those writes target a boolean-mask slice of the CSV frame, which triggers
# pandas' SettingWithCopyWarning and has ambiguous view/copy semantics.
lexicon_df = lexicon_df[lexicon_df.Category.isin(needed_categories)].copy()

In [None]:


# Messages: discard rows that have no text at all.
messages_df = pd.read_csv(messages_path).dropna(subset=['text'])

conversations_df = pd.read_csv(conversations_path)

# Conversations that appear in both files, i.e. have a conversation-level row
# and at least one non-empty message.
labeled_ids = set(conversations_df['engagement_id'].unique()).intersection(
    messages_df['engagement_id'].unique()
)
labeled_convs = conversations_df.loc[conversations_df['engagement_id'].isin(labeled_ids)]

# Conversation-level split, stratified on the `gsr` column.
train_conv, test_conv = train_test_split(
    labeled_convs,
    stratify=labeled_convs['gsr'],
    test_size=test_size,
    random_state=seed,
)

# Message-level splits follow the conversation split through engagement_id.
train_ids = set(train_conv['engagement_id'].to_numpy())
test_ids = set(test_conv['engagement_id'].to_numpy())

train_msgs = messages_df.loc[messages_df['engagement_id'].isin(train_ids)]
test_msgs = messages_df.loc[messages_df['engagement_id'].isin(test_ids)]
unlabeled_msgs = messages_df.loc[~messages_df['engagement_id'].isin(labeled_ids)]

# Pretraining corpus = messages outside the labeled set + the train split
# (messages of test conversations are excluded).
pretrain_messages_df = pd.concat([unlabeled_msgs, train_msgs])

print(f"train convs: {len(train_conv)}, train msgs: {len(train_msgs)}")
print(f"test convs: {len(test_conv)}, test msgs: {len(test_msgs)}")
print(f"pretraining (unlabeled + train ) convs: {len(pretrain_messages_df.engagement_id.unique())}, pretraining (unlabeled + train ) msgs: {len(pretrain_messages_df)}")


# Pre-process

In [None]:
# Collapse each conversation into one string: its messages, in row order,
# joined with a ' [SEP] ' separator. One output row per engagement_id.
pretrain_messages_df = (
    pretrain_messages_df
    .groupby('engagement_id')['text']
    .apply(' [SEP] '.join)
    .reset_index(name='text')
)
test_msgs = (
    test_msgs
    .groupby('engagement_id')['text']
    .apply(' [SEP] '.join)
    .reset_index(name='text')
)


def normalize_arabic(text):
    """Unify common Arabic letter variants in ``text`` and return the result.

    Applied in order: every alef form is mapped to bare alef, alef maqsura
    is mapped to yeh, and teh marbuta is mapped to heh. Diacritics and
    tatweel are NOT stripped by this function.
    """
    substitutions = (
        (r"[إأآٱا]", "ا"),  # alef variants -> bare alef
        (r"[يى]", "ي"),      # alef maqsura -> yeh
        (r"ة", "ه"),         # teh marbuta -> heh
    )
    for pattern, replacement in substitutions:
        text = re.sub(pattern, replacement, text)
    return text


def text2lexiconFreqVec(text: str) -> np.ndarray:
    """Count lexicon phrase occurrences in ``text``, summed per category.

    Each phrase in ``lexicon_df[LEXICON_ARABIC_PHRASE_COLUMN]`` is counted
    with ``str.count`` (non-overlapping substring matches) and the counts
    are summed within each ``Category``.

    Args:
        text: the (already normalized) text to scan.

    Returns:
        An int64 array with one entry per category in
        ``sorted(needed_categories)``. This is the same alphabetical order
        ``groupby`` produces, but categories with no lexicon rows now yield
        0 instead of being dropped, so the length is always
        ``len(needed_categories)``.
    """
    # Keep the counts in a local Series: writing a 'counts' column into the
    # shared lexicon_df on every call mutates global state for no benefit.
    phrase_counts = lexicon_df[LEXICON_ARABIC_PHRASE_COLUMN].apply(
        lambda phrase: text.count(phrase)
    )
    per_category = phrase_counts.groupby(lexicon_df["Category"]).sum()

    # Pin the output to a fixed set/order of categories so the vector length
    # always matches the regression head, even if a category has no phrases.
    per_category = per_category.reindex(sorted(needed_categories), fill_value=0)

    return per_category.astype("int64").values


    
# Normalize the lexicon phrases with the same function as the texts, so that
# substring matching compares like with like.
lexicon_df[LEXICON_ARABIC_PHRASE_COLUMN] = lexicon_df[LEXICON_ARABIC_PHRASE_COLUMN].map(normalize_arabic)

# For each split: normalize the joined conversation text, then derive its
# per-category lexicon count vector from the normalized text.
for frame in (pretrain_messages_df, test_msgs):
    frame['text'] = frame['text'].map(normalize_arabic)
    frame['lexiconVec'] = frame['text'].apply(text2lexiconFreqVec)

# Only the text and its regression target go into the HF datasets.
dataset_columns = ['text', "lexiconVec"]
raw_train = Dataset.from_pandas(pretrain_messages_df[dataset_columns])
raw_test = Dataset.from_pandas(test_msgs[dataset_columns])

In [None]:
# Tokenizer for the pretrained checkpoint named by MODEL_NAME.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

def tokenize_fn(batch):
    """Tokenize a batch of texts and carry the lexicon vectors along.

    Sequences are truncated to 512 tokens and left unpadded (padding is
    deferred to whatever collates the batches later). The batch's
    ``lexiconVec`` values are copied into the returned encoding unchanged.
    """
    encoded = tokenizer(
        batch["text"],
        padding=False,
        truncation=True,
        max_length=512,
    )
    # The regression targets must survive the removal of the "text" column.
    encoded["lexiconVec"] = batch["lexiconVec"]
    return encoded


# Tokenize both splits in batches; the raw "text" column is dropped, leaving
# the tokenizer outputs plus "lexiconVec".
tokenized_ds_train = raw_train.map(tokenize_fn, batched=True, remove_columns=["text"])
tokenized_ds_test = raw_test.map(tokenize_fn, batched=True, remove_columns=["text"])

# Define model class

In [None]:
class LitBERTClassifier(pl.LightningModule):
    """BERT masked-LM with an auxiliary lexicon-count regression head.

    Wraps ``BertForMaskedLM`` and adds a dropout + linear head on top of the
    final-layer hidden state of the first token ([CLS]). The training
    objective is ``alpha * mlm_loss + beta * MSE(reg_pred, lexiconVec)``.

    Args:
        model_name: checkpoint name or path passed to
            ``BertForMaskedLM.from_pretrained``.
        config: configuration object; only ``hidden_size`` is read, to size
            the input of the regression head.
        reg_dim: number of regression outputs. The default,
            ``len(needed_categories)``, is evaluated once when the class is
            defined.
        alpha: weight of the MLM loss.
        beta: weight of the regression (MSE) loss.
        learning_rate: learning rate for AdamW.
    """

    def __init__(self, model_name, config: BertConfig, reg_dim: int = len(needed_categories), alpha: float = 1.0, beta: float = 1.0, learning_rate=2e-5):
        super().__init__()
        self.save_hyperparameters()

        self.bert_mlm = BertForMaskedLM.from_pretrained(model_name)

        # Regression head: [CLS] hidden state -> reg_dim lexicon counts.
        self.reg_head = nn.Sequential(
            nn.Dropout(0.4),
            nn.Linear(config.hidden_size, reg_dim)
        )

        self.alpha = alpha  # weight for MLM loss
        self.beta = beta    # weight for regression (MSE) loss
        self.mse = nn.MSELoss()

    def forward(self, input_ids, attention_mask, token_type_ids=None, labels=None, lexiconVec=None):
        """Run the MLM backbone and the regression head.

        Args:
            input_ids, attention_mask, token_type_ids: standard BERT inputs.
            labels: MLM labels; when ``None`` the MLM loss is reported as 0.
            lexiconVec: regression targets matching ``reg_pred``'s shape;
                when ``None`` the regression loss is reported as 0.

        Returns:
            dict with ``loss`` (weighted sum), ``mlm_loss``, ``reg_loss``,
            ``logits`` (MLM logits) and ``reg_pred``. All three losses are
            always tensors so they can be logged unconditionally.
        """
        outputs = self.bert_mlm(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            labels=labels,
            return_dict=True,
            output_hidden_states=True
        )

        # Last layer, first position: the [CLS] token representation.
        cls_hidden = outputs.hidden_states[-1][:, 0, :]

        # Heads
        mlm_logits = outputs.logits
        reg_pred = self.reg_head(cls_hidden)  # (B, reg_dim)

        # Scalar placeholder for a loss component that cannot be computed;
        # keeps every returned loss a tensor so self.log never receives None.
        zero = torch.zeros((), device=reg_pred.device)

        mlm_loss = outputs.loss if outputs.loss is not None else zero

        if lexiconVec is not None:
            # Cast only after the None check; casting first made the None
            # case raise AttributeError before the check could run.
            reg_loss = self.mse(reg_pred, lexiconVec.to(torch.float32))
        else:
            print("lexiconVec is None, what is going on?!")
            reg_loss = zero

        loss = self.alpha * mlm_loss + self.beta * reg_loss

        return {
            "loss": loss,
            "mlm_loss": mlm_loss,
            "reg_loss": reg_loss,
            "logits": mlm_logits,     # for completeness
            "reg_pred": reg_pred,
        }

    def _shared_step(self, batch, stage: str):
        """Forward one batch and log ``{stage}_loss`` / ``_mlm_loss`` / ``_reg_loss``."""
        outputs = self(**batch)
        self.log(f"{stage}_loss", outputs['loss'], on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log(f"{stage}_mlm_loss", outputs['mlm_loss'], on_step=False, on_epoch=True, prog_bar=False, logger=True)
        self.log(f"{stage}_reg_loss", outputs['reg_loss'], on_step=False, on_epoch=True, prog_bar=False, logger=True)
        return outputs['loss']

    def training_step(self, batch, batch_idx):
        """Compute and log the training losses; returns the combined loss."""
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation losses; returns the combined loss."""
        return self._shared_step(batch, "val")

    def configure_optimizers(self):
        """AdamW over all parameters with the configured learning rate."""
        return torch.optim.AdamW(self.parameters(), lr=self.hparams.learning_rate)

# Train

In [None]:
from transformers import AutoConfig
from transformers import DataCollatorForLanguageModeling, TrainingArguments, Trainer

# Seed Python, NumPy and PyTorch before anything random happens (regression
# head initialization, dropout, MLM masking) so a Restart & Run All is
# repeatable. Reuses the `seed` already used for the train/test split.
pl.seed_everything(seed, workers=True)

# The config is only used to read hidden_size for the regression head.
base_config = AutoConfig.from_pretrained(MODEL_NAME)
model = LitBERTClassifier(
    MODEL_NAME,
    config=base_config,
    reg_dim=len(needed_categories),
    alpha=1.0,   # weight for MLM
    beta=1.0,    # weight for regression
    learning_rate=7e-6
)


# MLM on (set mlm_probability as needed)
collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=True,
    mlm_probability=0.15
)



def to_float32(batch):
    """Convert every vector in ``batch["lexiconVec"]`` to a float32 NumPy array.

    The batch is updated in place and also returned, as expected by
    ``Dataset.map(..., batched=True)``.
    """
    float_vectors = []
    for vector in batch["lexiconVec"]:
        float_vectors.append(np.asarray(vector, dtype=np.float32))
    batch["lexiconVec"] = float_vectors
    return batch


# Cast the lexicon targets of both splits to float32.
train_dataset = tokenized_ds_train.map(to_float32, batched=True)
eval_dataset = tokenized_ds_test.map(to_float32, batched=True)

batch_size = 8
# shuffle=True for training: the rows come out of a groupby and are therefore
# sorted by engagement_id; without shuffling every epoch would see the same
# ordered sequence of batches. Validation order is left untouched.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collator)
val_loader = DataLoader(eval_dataset, batch_size=batch_size, collate_fn=collator)


exp_name = "mlm & reg arabertv02 large (lr=7e-6)"
wandb_logger = WandbLogger(project="gsr pred", name=exp_name,
                               group="FULL CONV", resume=False)

trainer = pl.Trainer(
    max_epochs=10,
    accelerator="cuda",
    precision=16,              # enables mixed precision
    logger=wandb_logger,
)


trainer.fit(model, train_loader, val_loader)

wandb.finish()

# Persist the Lightning checkpoint, the bare masked-LM weights and the
# tokenizer under pretrain/<exp_name>.
# NOTE(review): the checkpoint file name says "marbertv2" although MODEL_NAME
# is an AraBERT checkpoint; left unchanged so existing paths keep working.
save_dir_name = f"pretrain/{exp_name}"
trainer.save_checkpoint(f"./{save_dir_name}/marbertv2-pretrain-pl.ckpt")
model.bert_mlm.save_pretrained(f"./{save_dir_name}/bert_pretrained")

tokenizer.save_pretrained(save_dir_name)