In [2]:
import tensorflow as tf
import numpy as np
from mamba_model import init_model, ModelArgs

tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

gpus = tf.config.list_physical_devices('GPU')
if gpus:
  # Restrict TensorFlow to only use the first GPU
  try:
    tf.config.set_visible_devices(gpus[1], 'GPU')
    logical_gpus = tf.config.list_logical_devices('GPU')
    print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPU")
  except RuntimeError as e:
    # Visible devices must be set before GPUs have been initialized
    print(e)

In [3]:
from mamba_model import vocab_size

seq_length = 129
max_samples = 50000

args = ModelArgs(
    model_input_dims=128,
    model_states=32,
    num_layers=12,
    dropout_rate=0.2,
    vocab_size=vocab_size,
    use_lm_head=True,
    num_classes=vocab_size, 
    loss='sparse_categorical_crossentropy'
)
model = init_model(args)
model.optimizer = tf.keras.optimizers.AdamW(learning_rate=0.001)
# Recompiler le modèle avec le nouvel optimizer
model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer=model.optimizer,
    metrics=['accuracy']
)
model.build(input_shape=(None, args.seq_length))
model.summary()

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


In [4]:
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

dataset = load_dataset("wikitext", "wikitext-103-v1")

'(ProtocolError('Connection aborted.', RemoteDisconnected('Remote end closed connection without response')), '(Request ID: 7abd0f98-9e68-4d36-ba16-e0addeda8e1d)')' thrown while requesting HEAD https://huggingface.co/bert-base-uncased/resolve/main/tokenizer_config.json
Retrying in 1s [Retry 1/5].


In [5]:
train_ids_input = []
train_ids_labels = []

count = 0

for item in tqdm(dataset['train']):
    if count >= max_samples:
        break

    text = item['text'].strip()

    # Sauter les textes vides
    if len(text) < 50:
        continue

    try:
        # Tokeniser
        tokens = tokenizer.encode(text, truncation=True)

        # au moins seq_length tokens
        if len(tokens) < seq_length:
            continue

        # Créer plusieurs exemples à partir d'un long texte (sliding window)
        for i in range(0, len(tokens) - seq_length, seq_length // 2):  # overlap de 50%
            chunk = tokens[i:i + seq_length]

            if len(chunk) < seq_length:
                # Padding
                chunk = chunk + [tokenizer.pad_token_id] * (seq_length - len(chunk))

            # Input et label décalés
            train_ids_input.append(chunk[:-1])
            train_ids_labels.append(chunk[1:])

            count += 1

            if count >= max_samples:
                break

    except Exception as e:
        continue

train_ids_input = np.array(train_ids_input, dtype=np.int32)
train_ids_labels = np.array(train_ids_labels, dtype=np.int32)

print(f"Préparé {len(train_ids_input)} exemples")
print(f"Input shape: {train_ids_input.shape}")
print(f"Labels shape: {train_ids_labels.shape}")

  7%|▋         | 129203/1801350 [00:16<03:30, 7929.94it/s]


Préparé 50000 exemples
Input shape: (50000, 128)
Labels shape: (50000, 128)


In [6]:
BATCH_SIZE = 16
train_dataset = tf.data.Dataset.from_tensor_slices((train_ids_input, train_ids_labels)).batch(BATCH_SIZE).shuffle(1000)

In [7]:
history = model.fit(train_dataset, epochs=10)

Epoch 1/10
[1m   4/3125[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m8:04:37[0m 9s/step - accuracy: 0.0054 - loss: 15.4089 

KeyboardInterrupt: 

In [9]:
import os
import json

save_dir = "mamba_model_sauvegarde"
os.makedirs(save_dir, exist_ok=True)

# Sauvegarder les poids 
model.save_weights(os.path.join(save_dir, "model_weights.weights.h5"))


# Sauvegarder la configuration
config = {
    'model_input_dims': args.model_input_dims,
    'model_states': args.model_states,
    'num_layers': args.num_layers,
    'dropout_rate': args.dropout_rate,
    'vocab_size': args.vocab_size,
    'num_classes': args.num_classes,
    'seq_length': args.seq_length,
    'conv_kernel_size': args.conv_kernel_size,
    'use_lm_head': args.use_lm_head,
    'loss': str(args.loss),
}

with open(os.path.join(save_dir, "config.json"), "w") as f:
    json.dump(config, f, indent=4)

# Sauvegarder le tokenizer
tokenizer.save_pretrained(save_dir)

('mamba_model_sauvegarde\\tokenizer_config.json',
 'mamba_model_sauvegarde\\special_tokens_map.json',
 'mamba_model_sauvegarde\\vocab.txt',
 'mamba_model_sauvegarde\\added_tokens.json',
 'mamba_model_sauvegarde\\tokenizer.json')