# Train SHARE Model

## Train SAUTE model with MLM Loss

DO NOT RUN IN LOCAL

In [1]:
!mkdir sources
!curl https://raw.githubusercontent.com/Just1truc/share-qa/refs/heads/main/sources/datasets.py -o sources/datasets.py
!curl https://raw.githubusercontent.com/Just1truc/share-qa/refs/heads/main/sources/saute_model.py -o sources/saute_model.py
!curl https://raw.githubusercontent.com/Just1truc/share-qa/refs/heads/main/sources/saute_config.py -o sources/saute_config.py
!curl https://raw.githubusercontent.com/Just1truc/share-qa/refs/heads/main/sources/mlm_logger.py -o sources/mlm_logger.py

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  5117  100  5117    0     0  11185      0 --:--:-- --:--:-- --:--:-- 11196
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 17655  100 17655    0     0  37186      0 --:--:-- --:--:-- --:--:-- 37168
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  1290  100  1290    0     0   2853      0 --:--:-- --:--:-- --:--:--  2860
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  2301  100  2301    0     0   5701      0 --:--:-- --:--:-- --:--:--  5695


You might need to restart session to actualize jupiter notebook env here

### Installing dependencies

In [2]:
%pip install flash-attn==1.0.8 --no-build-isolation
%pip install -U transformers
%pip install datasets
%pip install torchinfo

Collecting flash-attn==1.0.8
  Downloading flash_attn-1.0.8.tar.gz (2.0 MB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/2.0 MB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m2.0/2.0 MB[0m [31m83.3 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m53.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting ninja (from flash-attn==1.0.8)
  Downloading ninja-1.11.1.4-py3-none-manylinux_2_12_x86_64.manylinux2010_x86_64.whl.metadata (5.0 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch->flash-attn==1.0.8)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch->flash-attn==1.0.8)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Co

#### Imports

In [2]:
import torch

from transformers         import Trainer, TrainingArguments, BertConfig, BertForMaskedLM, BertTokenizerFast
from sources.saute_model  import UtteranceEmbedings
from sources.saute_config import SAUTEConfig
from sources.datasets     import SAUTEDataset
from sources.mlm_logger   import WandbPredictionLoggerCallback
from torchinfo            import summary

#### Accuracy Metric

In [3]:
def compute_masked_accuracy(logits, labels):
    # logits: [batch_size, seq_len, vocab_size]
    # labels: [batch_size, seq_len]

    preds = torch.argmax(logits, dim=-1)  # [batch_size, seq_len]

    # Only consider masked positions (labels != -100)
    mask = labels != -100

    # Count correct predictions
    correct = (preds == labels) & mask
    accuracy = correct.sum().float() / mask.sum()

    return accuracy.item()

def compute_metrics(eval_preds):
    logits, labels = eval_preds
    logits = torch.tensor(logits)
    labels = torch.tensor(labels)
    acc = compute_masked_accuracy(logits, labels)
    return {"masked_accuracy": acc}

#### Load Dataset

In [3]:
train_dataset = SAUTEDataset(split="train", dialog_format="edu")
eval_dataset = SAUTEDataset(split="test", dialog_format="edu")

from torch.utils.data import Subset
import random

subset_size = 25
indices = random.sample(range(len(eval_dataset)), subset_size)
test_dataset = Subset(eval_dataset, indices)

#### Load Model

In [4]:
model_config = SAUTEConfig(
    num_attention_heads = 12,
    num_hidden_layers   = 2
)
model = UtteranceEmbedings(model_config).to("cuda:0")

#### Training

In [5]:
fixed_batch = train_dataset[0]

#### Model Statistics

In [6]:

class ModelWrapper(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids, attention_mask):
        return self.model(input_ids=input_ids, attention_mask=attention_mask, speaker_names=fixed_batch["speaker_names"])

wrapped_model = ModelWrapper(model).to("cuda:0")

summary(
    wrapped_model,
    input_data=(fixed_batch["input_ids"].to("cuda:0"), fixed_batch["attention_mask"].to("cuda:0")),
    col_names=["input_size", "output_size", "num_params", "mult_adds"],
    depth=3,
    verbose=1
)

Layer (type:depth-idx)                                            Input Shape               Output Shape              Param #                   Mult-Adds
ModelWrapper                                                      [6, 128]                  [6, 128, 30522]           --                        --
├─UtteranceEmbedings: 1-1                                         --                        [6, 128, 30522]           --                        --
│    └─HSauteUnit: 2-1                                            --                        --                        --                        --
│    │    └─Embedding: 3-1                                        [6, 128]                  [6, 128, 768]             23,440,896                140,645,376
│    │    └─Embedding: 3-2                                        [6, 128]                  [6, 128, 768]             393,216                   2,359,296
│    │    └─ModuleList: 3-3                                       --                        -- 

Layer (type:depth-idx)                                            Input Shape               Output Shape              Param #                   Mult-Adds
ModelWrapper                                                      [6, 128]                  [6, 128, 30522]           --                        --
├─UtteranceEmbedings: 1-1                                         --                        [6, 128, 30522]           --                        --
│    └─HSauteUnit: 2-1                                            --                        --                        --                        --
│    │    └─Embedding: 3-1                                        [6, 128]                  [6, 128, 768]             23,440,896                140,645,376
│    │    └─Embedding: 3-2                                        [6, 128]                  [6, 128, 768]             393,216                   2,359,296
│    │    └─ModuleList: 3-3                                       --                        -- 

#### Init Training necessities

In [1]:
fixed_batch = train_dataset[1]
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

# Initialize the callback
wandb_logger_callback = WandbPredictionLoggerCallback(
    fixed_batch=fixed_batch,
    tokenizer=tokenizer,
    log_every_steps=50
)

NameError: name 'train_dataset' is not defined

In [None]:
tokenizer_name = "bert-base-uncased"
training_args = TrainingArguments(
    output_dir="h-saute-mlm-76m-0.0.1",
    eval_strategy="steps",
    eval_steps=150,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    num_train_epochs=3,
    weight_decay=0.01,
    save_strategy="epoch",
    logging_dir="./logs",
    logging_steps=150,
    fp16=True,
    max_steps=1506100
    # deepspeed="deepspeed_config.json",  # optional
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    tokenizer=None,
    data_collator=lambda batch: batch[0],
    callbacks=[wandb_logger_callback],
    eval_dataset=test_dataset,
    compute_metrics=compute_metrics
)

trainer.train()

  trainer = Trainer(
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mjustinduc[0m ([33mjustinduc-epitech[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Step,Training Loss,Validation Loss,Masked Accuracy
150,5.8348,6.220988,0.069164
300,5.372,6.107636,0.075501
450,5.5118,5.931244,0.110283
600,5.6702,5.749789,0.107784
750,5.3456,5.531967,0.1312
900,5.8897,5.668138,0.112462
1050,5.6505,5.646456,0.131356
1200,4.9095,5.522082,0.127055
1350,5.3103,5.577678,0.130758
1500,5.1813,5.491559,0.130045


### Bert Baseline

#### Load Model and dataset

In [4]:
train_dataset = SAUTEDataset(split="train", dialog_format="full")
eval_dataset = SAUTEDataset(split="test", dialog_format="full")

from torch.utils.data import Subset
import random

subset_size = 50
indices = random.sample(range(len(eval_dataset)), subset_size)
test_dataset = Subset(eval_dataset, indices)

bert_config = BertConfig(
    vocab_size=30522,
    hidden_size=768,
    num_hidden_layers=6,
    num_attention_heads=12,
    intermediate_size=3072,
    max_position_embeddings=512,
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1
)
model = BertForMaskedLM(config=bert_config)

In [5]:
fixed_batch = train_dataset[1]
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

wandb_logger_callback = WandbPredictionLoggerCallback(
    fixed_batch=fixed_batch,
    tokenizer=tokenizer,
    log_every_steps=10
)

#### Model Statistics

In [6]:
summary(
    model,
    input_data=(fixed_batch["input_ids"], fixed_batch["attention_mask"]),
    col_names=["input_size", "output_size", "num_params", "mult_adds"],
    depth=3,
    verbose=1
)

Layer (type:depth-idx)                                       Input Shape               Output Shape              Param #                   Mult-Adds
BertForMaskedLM                                              [1, 128]                  [1, 128, 30522]           --                        --
├─BertModel: 1-1                                             [1, 128]                  [1, 128, 768]             --                        --
│    └─BertEmbeddings: 2-1                                   --                        [1, 128, 768]             --                        --
│    │    └─Embedding: 3-1                                   [1, 128]                  [1, 128, 768]             23,440,896                23,440,896
│    │    └─Embedding: 3-2                                   [1, 128]                  [1, 128, 768]             1,536                     1,536
│    │    └─Embedding: 3-3                                   [1, 128]                  [1, 128, 768]             393,216          

Layer (type:depth-idx)                                       Input Shape               Output Shape              Param #                   Mult-Adds
BertForMaskedLM                                              [1, 128]                  [1, 128, 30522]           --                        --
├─BertModel: 1-1                                             [1, 128]                  [1, 128, 768]             --                        --
│    └─BertEmbeddings: 2-1                                   --                        [1, 128, 768]             --                        --
│    │    └─Embedding: 3-1                                   [1, 128]                  [1, 128, 768]             23,440,896                23,440,896
│    │    └─Embedding: 3-2                                   [1, 128]                  [1, 128, 768]             1,536                     1,536
│    │    └─Embedding: 3-3                                   [1, 128]                  [1, 128, 768]             393,216          

#### Train model

In [None]:
training_args = TrainingArguments(
    output_dir="bert-baseline-mlm-90m",
    per_device_train_batch_size=1,
    # save_strategy="steps",
    # save_steps=1000,
    eval_strategy="steps",
    eval_steps=150,
    logging_steps=150,
    learning_rate=5e-5,
    num_train_epochs=3,
    weight_decay=0.01,
    report_to="wandb",
)

# Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    data_collator=lambda batch: batch[0],
    callbacks=[wandb_logger_callback],
    compute_metrics=compute_metrics
)

trainer.train()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mjustinduc[0m ([33mjustinduc-epitech[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Step,Training Loss,Validation Loss,Masked Accuracy
150,7.283,6.017062,0.045455
300,5.9119,5.686316,0.115385
450,5.6468,5.427762,0.114035
600,5.7985,5.42051,0.144
750,5.5983,5.362011,0.117241
900,5.6147,5.330958,0.190141
1050,5.5305,5.254026,0.104839
1200,5.5764,5.606144,0.126866
1350,5.4112,5.772059,0.128378
1500,5.5253,5.278173,0.172414


Step,Training Loss,Validation Loss,Masked Accuracy
150,7.283,6.017062,0.045455
300,5.9119,5.686316,0.115385
450,5.6468,5.427762,0.114035
600,5.7985,5.42051,0.144
750,5.5983,5.362011,0.117241
900,5.6147,5.330958,0.190141
1050,5.5305,5.254026,0.104839
1200,5.5764,5.606144,0.126866
1350,5.4112,5.772059,0.128378
1500,5.5253,5.278173,0.172414
