# Train SHARE Model

## Train SAUTE model with MLM Loss

**Do not run locally** — this notebook expects a CUDA GPU (the model is moved to `cuda:0` and training uses fp16).

In [None]:
!mkdir sources
!curl https://raw.githubusercontent.com/Just1truc/share-qa/refs/heads/main/sources/datasets.py -o sources/datasets.py
!curl https://raw.githubusercontent.com/Just1truc/share-qa/refs/heads/main/sources/saute_model.py -o sources/saute_model.py
!curl https://raw.githubusercontent.com/Just1truc/share-qa/refs/heads/main/sources/saute_config.py -o sources/saute_config.py

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  2279  100  2279    0     0   8953      0 --:--:-- --:--:-- --:--:--  8937
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 11160  100 11160    0     0  40972      0 --:--:-- --:--:-- --:--:-- 41029
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  1290  100  1290    0     0   5502      0 --:--:-- --:--:-- --:--:--  5512


You may need to restart the session here so the Jupyter kernel picks up the newly downloaded `sources/` modules.

### Installing dependencies

In [None]:
# flash-attn is pinned to 1.0.8 and is built from its source tarball (see the
# "Preparing metadata (pyproject.toml)" step in the output below).
# NOTE(review): transformers (upgraded with -U) and datasets are unpinned, so a
# re-run may install different versions than the ones this notebook was written
# against — pin them for reproducibility. A session restart may be needed
# afterwards for the kernel to see the upgraded packages.
%pip install flash-attn==1.0.8 --no-build-isolation
%pip install -U transformers
%pip install datasets

Collecting flash-attn==1.0.8
  Downloading flash_attn-1.0.8.tar.gz (2.0 MB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/2.0 MB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.5/2.0 MB[0m [31m15.8 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m32.3 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting ninja (from flash-attn==1.0.8)
  Downloading ninja-1.11.1.4-py3-none-manylinux_2_12_x86_64.manylinux2010_x86_64.whl.metadata (5.0 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch->flash-attn==1.0.8)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch->flash-attn==1.0.8)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1

#### Imports

In [None]:
from transformers import Trainer, TrainingArguments
from sources.saute_model import UtteranceEmbedings
from sources.saute_config import SAUTEConfig
from sources.datasets import SAUTEDataset
import torch

#### Load Dataset

In [None]:
# Training split for the SAUTE model. Each item is handed to the Trainer as-is
# by saute_data_collator (per_device_train_batch_size=1).
train_dataset = SAUTEDataset("train")

#### Load Model

In [None]:
# Build the SAUTE utterance-embedding model from its default configuration.
model_config = SAUTEConfig()
model = UtteranceEmbedings(model_config)
# Place it on the first CUDA device; the notebook assumes a GPU runtime.
model = model.to("cuda:0")

### Set up the MLM prediction logger

In [None]:
from transformers import TrainerCallback
import torch
import wandb

class WandbPredictionLoggerCallback(TrainerCallback):
    """Periodically log masked-LM predictions on a fixed batch to a wandb Table.

    Every ``log_every_steps`` global steps the model is run on ``fixed_batch``
    (with ``labels=None`` and under ``torch.no_grad``). For every position
    whose label is not ``IGNORE_INDEX``, the target token and the arg-max
    predicted token are decoded. One table row is logged per sequence.

    Args:
        fixed_batch: Mapping with ``input_ids``, ``attention_mask`` and
            ``labels`` tensors, indexed as ``[sequence, position]`` below.
            It may also carry ``speaker_names``, which is forwarded to the
            model unchanged when present.
        tokenizer: Tokenizer used for decoding; it must provide
            ``mask_token_id`` and ``decode``.
        log_every_steps: Logging period in global steps; must be positive.

    Raises:
        ValueError: if ``log_every_steps`` is not positive.
    """

    # Label value for positions that carry no MLM target.
    IGNORE_INDEX = -100

    def __init__(self, fixed_batch, tokenizer, log_every_steps=500):
        if log_every_steps <= 0:
            # Fail at construction instead of with a ZeroDivisionError mid-training.
            raise ValueError(f"log_every_steps must be positive, got {log_every_steps}")
        self.fixed_batch = fixed_batch
        self.tokenizer = tokenizer
        self.log_every_steps = log_every_steps

    def on_step_end(self, args, state, control, model=None, **kwargs):
        """Log predictions every ``log_every_steps`` steps, from the main process only."""
        # In multi-process runs every rank calls this hook; log once.
        if model is None or not state.is_world_process_zero:
            return
        if state.global_step % self.log_every_steps == 0:
            self.log_predictions(model, state.global_step)

    def log_predictions(self, model, step):
        """Run ``model`` on the fixed batch and log one wandb table row per sequence.

        Args:
            model: The model being trained; its output must expose ``.logits``.
            step: Global step recorded in the table's "Step" column.
        """
        # ``model.device`` only exists on transformers PreTrainedModel; the
        # device of the first parameter works for any torch.nn.Module.
        device = next(model.parameters()).device

        input_ids = self.fixed_batch['input_ids'].to(device)
        attention_mask = self.fixed_batch['attention_mask'].to(device)
        labels = self.fixed_batch['labels'].to(device)

        model_kwargs = {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': None,
        }
        # speaker_names is not a tensor, so it is passed through without moving.
        # NOTE(review): a plain BertForMaskedLM may not accept this argument —
        # confirm the baseline batches do not carry it.
        if 'speaker_names' in self.fixed_batch:
            model_kwargs['speaker_names'] = self.fixed_batch['speaker_names']

        was_training = model.training
        model.eval()
        try:
            with torch.no_grad():
                logits = model(**model_kwargs).logits
        finally:
            # Put the model back in whatever mode the Trainer had it in.
            model.train(was_training)
        preds = torch.argmax(logits, dim=-1)

        table = wandb.Table(columns=["Step", "Masked Input", "Target Word", "Predicted Word"])
        mask_id = self.tokenizer.mask_token_id

        for row_ids, row_labels, row_preds in zip(input_ids, labels, preds):
            # Positions with a real label are the ones the MLM objective scores.
            target = row_labels != self.IGNORE_INDEX

            # Show every scored position as the mask token, regardless of what
            # the input actually holds there.
            masked_ids = row_ids.clone()
            masked_ids[target] = mask_id
            masked_text = self.tokenizer.decode(masked_ids.tolist(), skip_special_tokens=False)
            # Strip structural tokens manually: skip_special_tokens would also drop [MASK].
            for special in ("[SEP]", "[PAD]", "[CLS]"):
                masked_text = masked_text.replace(special, "")

            true_tokens = [self.tokenizer.decode([tok]) for tok in row_labels[target].tolist()]
            pred_tokens = [self.tokenizer.decode([tok]) for tok in row_preds[target].tolist()]

            table.add_data(step, masked_text, ",".join(true_tokens), ",".join(pred_tokens))

        wandb.log({"MLM Predictions Evolution": table})


#### Training

In [None]:
# Sanity check: inspect one row (index 2) of the first training item.
fixed_batch = train_dataset[0]
# Input ids for that row — masked positions presumably hold the mask id printed below.
print(fixed_batch["input_ids"][2])
# MLM labels for the same row; unscored positions are expected to be -100.
print(fixed_batch["labels"][2])
# Token id of the dataset tokenizer's mask token.
print(train_dataset.tokenizer.convert_tokens_to_ids(train_dataset.tokenizer.mask_token))

#### Initialize training utilities (data collator, tokenizer, logging callback)

In [None]:
def saute_data_collator(batch):
    """Unwrap a single, already-batched dataset item.

    Dataset items here appear to already be complete batches, so the training
    arguments use ``per_device_train_batch_size=1``. The DataLoader therefore
    passes a one-element list, and that element is returned unchanged.

    Args:
        batch: Sequence of dataset items produced by the DataLoader.

    Returns:
        The single item in ``batch``.

    Raises:
        ValueError: if ``batch`` does not hold exactly one item. Returning
            only ``batch[0]`` in that case would silently discard training data.
    """
    if len(batch) != 1:
        raise ValueError(
            f"saute_data_collator expects exactly one pre-batched item, got {len(batch)}; "
            "use a DataLoader batch size of 1"
        )
    return batch[0]

from transformers import BertTokenizerFast

# The first training item serves as the fixed batch whose predictions are tracked.
fixed_batch = train_dataset[0]
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

# Log masked-token predictions on that batch every 50 steps.
wandb_logger_callback = WandbPredictionLoggerCallback(fixed_batch, tokenizer, log_every_steps=50)

In [None]:
tokenizer_name = "bert-base-uncased"
training_args = TrainingArguments(
    output_dir="cross-speaker-mlm-display-6",
    eval_strategy="no",
    # Each dataset item appears to be a full batch already (see saute_data_collator).
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    # max_steps is set, so it overrides num_train_epochs as the stopping criterion.
    num_train_epochs=3,
    max_steps=1506100,
    weight_decay=0.01,
    save_strategy="epoch",
    logging_dir="./logs",
    logging_steps=10,
    fp16=True,
    # WandbPredictionLoggerCallback calls wandb.log directly, so enable wandb
    # reporting explicitly (as the BERT baseline run does) instead of relying
    # on the library default.
    report_to="wandb",
    # deepspeed="deepspeed_config.json",  # optional
)

# No tokenizer is passed: the `tokenizer=` Trainer kwarg is deprecated in recent
# transformers releases, and None was the default anyway.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    data_collator=saute_data_collator,
    callbacks=[wandb_logger_callback],
)

trainer.train()

#### BERT baseline imports

In [None]:
from transformers import BertConfig, BertForMaskedLM

#### Load Model and dataset

In [None]:
# Same "train" split as the SAUTE run, but with dialog_format="full" for BERT.
train_dataset = SAUTEDataset(split="train", dialog_format="full")

# 6-layer BERT encoder; the model is built from a config (not from_pretrained),
# so the baseline starts from randomly initialized weights.
bert_config = BertConfig(
    vocab_size=30522,
    max_position_embeddings=512,
    hidden_size=768,
    intermediate_size=3072,
    num_hidden_layers=6,
    num_attention_heads=12,
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
)
model = BertForMaskedLM(bert_config)

In [None]:
from transformers import BertTokenizerFast

# Track predictions on the first item of the "full"-format training set.
fixed_batch = train_dataset[0]
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

# Same logging cadence as the SAUTE run: every 50 steps.
wandb_logger_callback = WandbPredictionLoggerCallback(fixed_batch, tokenizer, log_every_steps=50)

#### Train model

In [None]:
# NOTE(review): unlike the SAUTE run above, this baseline uses neither fp16 nor
# max_steps (it trains for num_train_epochs) — confirm the difference is intended.
training_args = TrainingArguments(
    output_dir="bert-baseline",
    per_device_train_batch_size=1,  # each dataset item appears to be a full batch already
    save_strategy="steps",          # save a checkpoint every `save_steps` steps
    save_steps=1000,
    logging_steps=50,               # log the training loss every 50 steps
    learning_rate=5e-5,
    num_train_epochs=3,
    weight_decay=0.01,
    report_to="wandb",              # wandb tracking; also needed by the prediction-logging callback
    run_name="baseline-bert-mlm"
)

# Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    data_collator=saute_data_collator,
    callbacks=[wandb_logger_callback]
)

trainer.train()