In [None]:
from datasets import load_dataset

# Load the "argument_quality_ranking" configuration of the IBM 30k dataset.
# Without a `split=` argument every split comes back in one DatasetDict.
dataset = load_dataset(
    "ibm-research/argument_quality_ranking_30k",
    "argument_quality_ranking",
)

# Quick look: an overview of the splits, then the first training example
# (left as the cell's last expression so the notebook renders it).
print(dataset)
dataset["train"][0]

In [None]:
from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")


def preprocess(examples):
    """Tokenize a batch of arguments.

    Intended for ``dataset.map(..., batched=True)``: ``examples["argument"]``
    holds a batch of raw texts. Every sequence is truncated or padded to
    exactly 128 tokens, so all rows share the same length.

    Returns the tokenizer's encoding (``input_ids``, ``attention_mask``, ...).
    """
    texts = examples["argument"]
    encoding = tokenizer(
        texts,
        max_length=128,
        padding="max_length",
        truncation=True,
    )
    return encoding

# Tokenize every split. Only the raw-text / annotation columns are dropped;
# "MACE-P" is kept because it is used as the regression label later on.
columns_to_drop = ["argument", "topic", "set", "WA", "stance_WA", "stance_WA_conf"]
tokenized = dataset.map(preprocess, batched=True, remove_columns=columns_to_drop)

In [None]:
import tensorflow as tf

def to_tf_dataset(split):
    """Turn one tokenized split into an unbatched ``tf.data.Dataset``.

    Each element is ``(features, label)``:
      * ``features`` maps ``"input_ids"`` and ``"attention_mask"`` to int32 tensors;
      * ``label`` is the split's ``"MACE-P"`` value as float32.
    """
    inputs = {
        column: tf.constant(split[column], dtype=tf.int32)
        for column in ("input_ids", "attention_mask")
    }
    targets = tf.constant(split["MACE-P"], dtype=tf.float32)
    return tf.data.Dataset.from_tensor_slices((inputs, targets))

# Batch size shared by the three pipelines.
BATCH_SIZE = 32

# Shuffle the training split with a buffer as large as the split itself: a fixed
# 10_000-element buffer only partially shuffles a larger split, so the first
# batches of every epoch would still be drawn from its first rows.
# prefetch() overlaps input preparation with the training step.
train_ds = (
    to_tf_dataset(tokenized["train"])
    .shuffle(buffer_size=len(tokenized["train"]))
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)
val_ds = to_tf_dataset(tokenized["validation"]).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
test_ds = to_tf_dataset(tokenized["test"]).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

In [None]:
import tensorflow as tf
from tensorflow.keras.optimizers.schedules import LearningRateSchedule, CosineDecayRestarts

# 1. Cosine decay with warm restarts. Per the Keras docs: the first cycle lasts
# `first_decay_steps` steps, each later cycle is `t_mul` times longer, each
# restart peaks at `m_mul` times the previous peak, and the rate decays towards
# `alpha * initial_learning_rate`.
decay_schedule = CosineDecayRestarts(
    initial_learning_rate=3e-5,
    first_decay_steps=1000,  # tune to the total number of training steps
    alpha=0.0,               # decay all the way down to 0
    t_mul=2.0,               # every cycle twice as long as the previous one
    m_mul=1.0,               # every restart returns to the full initial rate
)

# 2. Linear warm-up in front of another learning-rate schedule.
class WarmUp(LearningRateSchedule):
    """Linear warm-up followed by another learning-rate schedule.

    For ``step < warmup_steps`` the rate is ``initial_lr * step / warmup_steps``.
    From then on it is ``decay_schedule_fn(step - warmup_steps)``, so the decay
    schedule starts from its own step 0 when warm-up ends.

    Args:
        initial_lr: Learning rate reached at the end of warm-up.
        decay_schedule_fn: Schedule used after warm-up. It is either a
            ``LearningRateSchedule`` instance or the serialized dict stored by
            ``get_config``, which lets the class round-trip through ``from_config``
            (``cls(**config)``).
        warmup_steps: Number of warm-up steps.
    """

    def __init__(self, initial_lr, decay_schedule_fn, warmup_steps):
        super().__init__()
        # Plain Python values so get_config() stays JSON-friendly.
        self.initial_lr = float(initial_lr)
        self.warmup_steps = int(warmup_steps)
        if isinstance(decay_schedule_fn, dict):
            # Rebuilt from a saved config: turn the description back into a schedule.
            decay_schedule_fn = tf.keras.optimizers.schedules.deserialize(decay_schedule_fn)
        self.decay_fn = decay_schedule_fn

    def __call__(self, step):
        step_float = tf.cast(step, tf.float32)
        warmup_steps_float = tf.cast(self.warmup_steps, tf.float32)
        # max(..., 1) only matters when warmup_steps == 0. In that case this branch
        # is never selected, but it would otherwise compute 0 / 0.
        warmup_lr = self.initial_lr * (step_float / tf.maximum(warmup_steps_float, 1.0))
        decay_lr = self.decay_fn(step - self.warmup_steps)
        return tf.where(step_float < warmup_steps_float, warmup_lr, decay_lr)

    def get_config(self):
        """Return a serializable config; inverse of ``WarmUp(**config)``."""
        return {
            "initial_lr": self.initial_lr,
            "warmup_steps": self.warmup_steps,
            # Store a description of the decay schedule, not a live object.
            # The original code called deserialize() here, which produced an object.
            "decay_schedule_fn": tf.keras.optimizers.schedules.serialize(self.decay_fn),
        }

# Warm up linearly to the cosine schedule's own starting rate, so the learning
# rate does not jump when warm-up hands over to `decay_schedule` (previously the
# same 3e-5 literal was duplicated here and had to be kept in sync by hand).
lr_schedule = WarmUp(
    initial_lr=decay_schedule.initial_learning_rate,
    decay_schedule_fn=decay_schedule,
    warmup_steps=500,
)

optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

In [None]:
from tensorflow.keras import layers

class PositionalEncoding(layers.Layer):
    """Add fixed sinusoidal position encodings to the input.

    A ``(1, max_len, d_model)`` table is built once in ``__init__``. ``call`` adds
    its first ``seq_len`` rows to ``x``. ``x`` is expected to be
    ``(batch, seq_len, d_model)`` with ``seq_len <= max_len``.

    Args:
        max_len: Longest sequence the table covers.
        d_model: Size of the last (feature) axis of the input.
        **kwargs: Standard Keras layer arguments (``name``, ``dtype``, ...). They
            are forwarded so that the layer can be rebuilt from ``get_config()``.
    """

    def __init__(self, max_len, d_model, **kwargs):
        super().__init__(**kwargs)
        self.max_len = int(max_len)
        self.d_model = int(d_model)
        pos = tf.range(self.max_len, dtype=tf.float32)[:, tf.newaxis]   # (max_len, 1)
        i = tf.range(self.d_model, dtype=tf.float32)[tf.newaxis, :]     # (1, d_model)
        # Dimensions 2k and 2k+1 share the frequency 1 / 10000^(2k / d_model).
        angle = pos / tf.pow(10000.0, (2 * (i // 2)) / tf.cast(self.d_model, tf.float32))
        # Even dimensions use sin, odd dimensions use cos.
        pe = tf.where(tf.cast(i, tf.int32) % 2 == 0, tf.sin(angle), tf.cos(angle))
        self.pe = pe[tf.newaxis, ...]

    def call(self, x):
        # Cast so the addition also works when x is not float32 (e.g. mixed precision).
        return x + tf.cast(self.pe[:, :tf.shape(x)[1], :], x.dtype)

    def get_config(self):
        # Required for saving the full model (e.g. ModelCheckpoint to .h5): the
        # constructor takes arguments the base Layer config does not know about.
        config = super().get_config()
        config.update({"max_len": self.max_len, "d_model": self.d_model})
        return config

def build_transformer_model(
    vocab_size: int,
    max_len: int = 128,
    d_model: int = 128,
    num_heads: int = 4,
    ff_dim: int = 256,
    num_layers: int = 2,
):
    """Build a small Transformer-encoder regressor (weights trained from scratch).

    The model is built as follows:
      * token embedding plus sinusoidal positions;
      * ``num_layers`` post-norm encoder blocks, each made of self-attention and
        a feed-forward sub-layer, with a residual connection and
        LayerNormalization around each;
      * a mean over the non-padding positions;
      * one linear output unit named ``"mace_p"``.

    Args:
        vocab_size: Number of rows of the embedding table.
        max_len: Fixed length of both inputs.
        d_model: Width of the embeddings and of every encoder block output.
        num_heads: Attention heads per block.
        ff_dim: Hidden width of the feed-forward sub-layer.
        num_layers: Number of encoder blocks.

    Returns:
        A functional ``tf.keras.Model`` with int32 inputs ``input_ids`` and
        ``attention_mask``, each of shape ``(batch, max_len)``, and an output of
        shape ``(batch, 1)``. The mask is assumed to follow the tokenizer
        convention: 1 = real token, 0 = padding.
    """
    input_ids = layers.Input(shape=(max_len,), dtype=tf.int32, name="input_ids")
    att_mask = layers.Input(shape=(max_len,), dtype=tf.int32, name="attention_mask")

    # (batch, 1, 1, key_len): broadcast over heads and query positions so that no
    # query attends to padded keys. Built once and shared by every block.
    key_padding_mask = att_mask[:, tf.newaxis, tf.newaxis, :]

    x = layers.Embedding(vocab_size, d_model)(input_ids)
    # Add positions first, then apply dropout to the sum, as in the original
    # Transformer. Previously dropout was applied to the embeddings only.
    x = PositionalEncoding(max_len, d_model)(x)
    x = layers.Dropout(0.1)(x)

    for _ in range(num_layers):
        # Self-attention sub-layer: residual connection, then LayerNormalization.
        # NOTE(review): key_dim=d_model gives every head the full model width; the
        # usual choice is d_model // num_heads. Kept unchanged here.
        attn_output = layers.MultiHeadAttention(
            num_heads=num_heads,
            key_dim=d_model,
            dropout=0.1,
        )(x, x, attention_mask=key_padding_mask)
        attn_output = layers.Dropout(0.1)(attn_output)
        attn_output = layers.LayerNormalization(epsilon=1e-6)(x + attn_output)

        # Position-wise feed-forward sub-layer: residual connection, then LayerNormalization.
        ff = layers.Dense(ff_dim, activation="relu")(attn_output)
        ff = layers.Dropout(0.1)(ff)
        ff = layers.Dense(d_model)(ff)
        x = layers.LayerNormalization(epsilon=1e-6)(attn_output + ff)

    # Average over real tokens only. Without the mask, short inputs padded to
    # max_len would have their representation diluted by padding positions.
    x = layers.GlobalAveragePooling1D()(x, mask=att_mask)
    output = layers.Dense(1, activation="linear", name="mace_p")(x)

    return tf.keras.Model(inputs=[input_ids, att_mask], outputs=output)

# The embedding table must have one row per token id the BERT tokenizer can emit.
vocab_size = tokenizer.vocab_size
model = build_transformer_model(vocab_size=vocab_size)
model.summary()

In [None]:
from tensorflow.keras.callbacks import Callback

class UnfreezeCallback(Callback):
    """Train with some layers frozen for the first ``freeze_epochs`` epochs.

    Frozen layers are every ``Embedding`` layer and every layer whose name contains
    ``'multi_head_attention'``. With Keras's default names this matches *all*
    MultiHeadAttention layers, not only the first one.

    NOTE(review): the model in this notebook is trained from scratch, so the
    frozen embeddings/attention weights are still at their random initialization
    during the frozen epochs. Confirm that this schedule is really wanted.

    Args:
        freeze_epochs: Number of epochs to train with those layers frozen.
            A value ``<= 0`` disables freezing.
    """

    def __init__(self, freeze_epochs=2):
        super().__init__()
        self.freeze_epochs = freeze_epochs
        self._frozen_layers = []  # layers this callback froze and must restore

    def on_train_begin(self, logs=None):
        self._frozen_layers = []
        if self.freeze_epochs <= 0:
            return  # nothing to freeze, and nothing would ever unfreeze it
        # Create optimizer state for every variable while all are still trainable.
        # Optimizers that expose build() (TF >= 2.11 / Keras 3) otherwise reject
        # variables that become trainable later. Older optimizers create slots lazily.
        build_optimizer = getattr(self.model.optimizer, "build", None)
        if callable(build_optimizer):
            build_optimizer(self.model.trainable_variables)
        for layer in self.model.layers:
            is_target = (
                isinstance(layer, tf.keras.layers.Embedding)
                or 'multi_head_attention' in layer.name
            )
            if is_target and layer.trainable:
                layer.trainable = False
                self._frozen_layers.append(layer)
        print("Couches initiales gelées")

    def on_epoch_end(self, epoch, logs=None):
        if epoch + 1 != self.freeze_epochs or not self._frozen_layers:
            return
        for layer in self._frozen_layers:
            layer.trainable = True
        self._frozen_layers = []
        # The compiled train function captured the list of trainable variables when
        # it was traced. Without a rebuild the unfrozen layers would never be
        # updated. fit() looks up model.train_function on every step, so the new
        # function is picked up from the next epoch.
        self.model.train_function = None
        self.model.make_train_function()
        print(f"Couches dé-gelées à l'époque {epoch+1}")

unfreeze_cb = UnfreezeCallback(freeze_epochs=2)

In [None]:
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping

# Regression on MACE-P: optimize MSE and track MAE. The metric is named "MAE",
# so Keras logs it as "val_MAE" on validation data, which is what the callbacks
# below monitor.
model.compile(
    loss="mean_squared_error",
    optimizer=optimizer,  # Adam driven by the warm-up + cosine-restarts schedule
    metrics=[tf.keras.metrics.MeanAbsoluteError(name="MAE")],
)

# Save the model each time the validation MAE reaches a new minimum.
# NOTE(review): saving the full model requires every custom object in it
# (custom layers, the LR schedule) to be serializable via get_config().
checkpoint_cb = ModelCheckpoint(
    "best_model.h5",
    save_best_only=True,
    monitor="val_MAE",
    mode="min",
    verbose=1,
)

# Stop after 3 epochs without a lower validation MAE and roll back to the best weights.
earlystop_cb = EarlyStopping(
    monitor="val_MAE",
    mode="min",
    patience=3,
    restore_best_weights=True,
    verbose=1,
)

callbacks = [checkpoint_cb, earlystop_cb, unfreeze_cb]

In [None]:
# Train for at most 15 epochs; the early-stopping callback may end training sooner.
history = model.fit(
    train_ds,
    epochs=15,
    validation_data=val_ds,
    callbacks=callbacks,
)

In [None]:
# Evaluate on the held-out test split. return_dict=True keys the results by metric
# name, so the MAE is read by the name given at compile time ("MAE") instead of
# relying on its position in the returned list.
result = model.evaluate(test_ds, return_dict=True)
print(f"Test MAE: {result['MAE']:.4f}")