In [1]:
!pip install --quiet flax jax jaxlib optax tensorflow-probability

import jax
import jax.numpy as jnp
import numpy as np
import flax.linen as nn
import optax
import json
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import pickle
import random
import seaborn as sns
import os

from functools import partial
from tqdm import tqdm, trange
from astropy.coordinates import SkyCoord
from astropy.table import Table
from astropy import units as u
from collections import Counter
from sklearn.metrics import confusion_matrix,roc_curve, auc, precision_recall_curve, average_precision_score

In [3]:
from google.colab import drive
drive.mount('/content/drive')
!pip install -q ipympl
from google.colab import output
output.enable_custom_widget_manager()

Mounted at /content/drive
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m515.7/515.7 kB[0m [31m32.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m80.4 MB/s[0m eta [36m0:00:00[0m
[?25h

In [4]:
# Name of the dataset/results sub-folder; it is interpolated into the Drive
# paths used to read the data, the configuration files and the checkpoints.
dossier = "5050"

In [5]:
def _print_config_table(title, mapping, key_width=15, value_width=10):
    """Print ``mapping`` as a box-drawn, two-column table.

    Args:
        title: Text shown in the header row of the box.
        mapping: Dict of setting name -> value; both are rendered with ``str``.
        key_width: Minimum width of the name column.
        value_width: Minimum width of the value column.

    The columns grow to fit the longest entry, so the right border stays
    aligned even when a name is longer than ``key_width``.
    """
    entries = [(str(key), str(value)) for key, value in mapping.items()]
    key_width = max([key_width] + [len(key) for key, _ in entries])
    value_width = max([value_width] + [len(value) for _, value in entries])
    # Row layout is "│ <key>: <value>   │": 1 + key + 2 + value + 3 inner characters.
    inner_width = key_width + value_width + 6

    print("┌" + "─" * inner_width + "┐")
    print("│" + f"  {title}".ljust(inner_width) + "│")
    print("├" + "─" * inner_width + "┤")
    for key, value in entries:
        print(f"│ {key.ljust(key_width)}: {value.rjust(value_width)}   │")
    print("└" + "─" * inner_width + "┘")


# Tokenizer / sequence-layout constants saved alongside the dataset.
with open(f"/content/drive/MyDrive/Stage_M2_2025/XV/data/{dossier}/constantes_du_modele.json", 'r') as f:
    config = json.load(f)

# Pull out the values used by the rest of the notebook.
MAX_SOURCES  = config["MAX_SOURCES"]
MAX_CLUSTERS = config["MAX_CLUSTERS"]
MAX_AGN = config["MAX_AGN"]
VOCAB_SIZE   = config["VOCAB_SIZE"]
PAD_TOKEN    = config["PAD_TOKEN"]
SEP_TOKEN    = config["SEP_TOKEN"]
CLS_TOKEN    = config["CLS_TOKEN"]

_print_config_table("CONFIGURATION DU MODÈLE", config)

# Transformer dimensions saved at training time.
# NOTE(review): the file name hard-codes D_MODEL=128 / NUM_HEADS=16 / NUM_LAYERS=4;
# it has to be edited by hand to evaluate another checkpoint.
with open(f"/content/drive/MyDrive/Stage_M2_2025/XV/results/{dossier}/dim_transformer_D_MODEL128_NUM_HEADS16_NUM_LAYERS4.json", 'r') as f:
    dim_transformer = json.load(f)

BATCH_SIZE   = dim_transformer["BATCH_SIZE"]
D_MODEL      = dim_transformer["D_MODEL"]
NUM_HEADS    = dim_transformer["NUM_HEADS"]
NUM_LAYERS   = dim_transformer["NUM_LAYERS"]

_print_config_table("CONFIGURATION DU MODÈLE", dim_transformer)

┌───────────────────────────────┐
│  CONFIGURATION DU MODÈLE      │
├───────────────────────────────┤
│ VOCAB_SIZE     :       1029   │
│ PAD_TOKEN      :       1024   │
│ SEP_TOKEN      :       1025   │
│ CLS_TOKEN      :       1026   │
│ SEP_AMAS       :       1027   │
│ SEP_AGN        :       1028   │
│ NOMBRE_TOKENS_SPECIAUX:          5   │
│ MAX_SOURCES    :          7   │
│ MAX_CLUSTERS   :          2   │
│ MAX_AGN        :          8   │
└───────────────────────────────┘
┌───────────────────────────────┐
│  CONFIGURATION DU MODÈLE      │
├───────────────────────────────┤
│ BATCH_SIZE     :         64   │
│ D_MODEL        :        128   │
│ NUM_HEADS      :         16   │
│ NUM_LAYERS     :          4   │
└───────────────────────────────┘


In [6]:
# Catalogue columns kept for each kind of object, in the order they are stored.
SELECTED_COLUMNS_Xamin    = ['EXT_LIKE', 'EXT', 'EXT_RA', 'EXT_DEC', 'EXT_RATE_MOS', 'EXT_RATE_PN', 'PNT_DET_ML', 'PNT_RA', 'PNT_DEC', 'PNT_RATE_MOS', 'PNT_RATE_PN']
SELECTED_COLUMNS_clusters = ['R.A.', 'Dec']
SELECTED_COLUMNS_AGN      = ['ra_mag_gal', 'dec_mag_gal']

# One flag per column above (same order): True when the column is handled on a log scale.
use_log_scale_Xamin    = [True, True, False, False, True, True, True, False, False, True, True]
use_log_scale_clusters = [False, False]
use_log_scale_AGN      = [False, False]

# Fail fast if a flag list drifts out of sync with its column list, instead of
# relying on a visual comparison of the counts printed below.
assert len(use_log_scale_Xamin) == len(SELECTED_COLUMNS_Xamin), \
    "use_log_scale_Xamin doit contenir un booléen par colonne de SELECTED_COLUMNS_Xamin"
assert len(use_log_scale_clusters) == len(SELECTED_COLUMNS_clusters), \
    "use_log_scale_clusters doit contenir un booléen par colonne de SELECTED_COLUMNS_clusters"
assert len(use_log_scale_AGN) == len(SELECTED_COLUMNS_AGN), \
    "use_log_scale_AGN doit contenir un booléen par colonne de SELECTED_COLUMNS_AGN"

# Column name -> position lookups.
columns_dict_Xamin = {column: index for index, column in enumerate(SELECTED_COLUMNS_Xamin)}
columns_dict_clusters = {column: index for index, column in enumerate(SELECTED_COLUMNS_clusters)}
columns_dict_AGN = {column: index for index, column in enumerate(SELECTED_COLUMNS_AGN)}

print(f'Nombre de colonnes SELECTED_COLUMNS_Xamin   : {len(SELECTED_COLUMNS_Xamin)}')
print(f'Nombre de colonnes SELECTED_COLUMNS_clusters: {len(SELECTED_COLUMNS_clusters)}')
print(f'Nombre de colonnes SELECTED_COLUMNS_AGN     : {len(SELECTED_COLUMNS_AGN)}')
print(f'Nombre de colonnes use_log_scale_Xamin      : {len(use_log_scale_Xamin)}')
print(f'Nombre de colonnes use_log_scale_clusters   : {len(use_log_scale_clusters)}')
print(f'Nombre de colonnes use_log_scale_AGN        : {len(use_log_scale_AGN)}')

Nombre de colonnes SELECTED_COLUMNS_Xamin   : 11
Nombre de colonnes SELECTED_COLUMNS_clusters: 2
Nombre de colonnes SELECTED_COLUMNS_AGN     : 2
Nombre de colonnes use_log_scale_Xamin      : 11
Nombre de colonnes use_log_scale_clusters   : 2
Nombre de colonnes use_log_scale_AGN        : 2


In [7]:
# Load the training and test data as int32 matrices (they are used as token
# sequences further down, one sequence per row).
X_train = np.loadtxt(f'/content/drive/MyDrive/Stage_M2_2025/XV/data/{dossier}/X_train.txt', dtype=np.int32)
X_test = np.loadtxt(f'/content/drive/MyDrive/Stage_M2_2025/XV/data/{dossier}/X_test.txt', dtype=np.int32)

# Load the transformer parameters; the file name is built from the D_MODEL /
# NUM_HEADS / NUM_LAYERS values read from the dim_transformer JSON.
# NOTE(review): pickle.load can execute arbitrary code — only open checkpoints
# that you produced yourself.
with open(f'/content/drive/MyDrive/Stage_M2_2025/XV/results/{dossier}/params_D_MODEL{D_MODEL}_NUM_HEADS{NUM_HEADS}_NUM_LAYERS{NUM_LAYERS}.pkl', 'rb') as f:
    params = pickle.load(f)

In [8]:
# Sanity check: shape of the train split, then of the test split, one per line.
print(X_train.shape, X_test.shape, sep="\n")

(92820, 101)
(23868, 101)


In [9]:
class MLP(nn.Module):
    """Feed-forward sub-network: Dense -> ReLU -> Dense.

    Both Dense layers have ``features`` output units.
    """
    # Number of output units of each of the two Dense layers.
    features: int

    @nn.compact
    def __call__(self, x):
        """Return ``Dense(relu(Dense(x)))`` for the input array ``x``."""
        x = nn.Dense(self.features)(x)
        x = nn.relu(x)
        x = nn.Dense(self.features)(x)
        return x

class TransformerBlock(nn.Module):
    """One transformer layer: attention and MLP sub-layers, each with a residual connection.

    LayerNorm is applied to the input of each sub-layer, and the sub-layer
    output is added back onto the un-normalised stream.
    """
    # Feature width handed to the MLP sub-layer.
    d_model: int
    # Number of attention heads.
    num_heads: int

    @nn.compact
    def __call__(self, x, mask):
        """Run the block on ``x`` using attention mask ``mask`` and return the updated ``x``.

        NOTE(review): the sub-modules are created in a fixed order inside this
        compact method; that order presumably determines their auto-generated
        parameter names, so reordering may break loading of saved parameters.
        """
        # Multi-head self-attention on the normalised input, then residual add.
        z = nn.LayerNorm()(x)
        z = nn.MultiHeadDotProductAttention(num_heads=self.num_heads)(z, mask=mask)
        x = x + z
        # Feed-forward sub-layer on the normalised stream, then residual add.
        z = nn.LayerNorm()(x)
        x += MLP(self.d_model)(z)
        return x

class AutoregressiveTransformerModel(nn.Module):
    """Causally-masked transformer over discrete tokens that outputs per-position logits.

    Pipeline: token embedding + positional embedding -> ``num_layers``
    TransformerBlock layers sharing one lower-triangular mask -> LayerNorm ->
    Dense projection onto ``vocab_size`` logits.
    """
    # Width of the token and positional embeddings, passed on to every block.
    d_model: int
    # Number of attention heads in every block (also used to shape the mask).
    num_heads: int
    # Number of stacked TransformerBlock layers.
    num_layers: int
    # Number of rows of the positional-embedding table; inputs are presumably
    # expected to be no longer than this — TODO confirm.
    seq_length: int
    # Number of rows of the token-embedding table and width of the output logits.
    vocab_size: int

    @nn.compact
    def __call__(self, x):
        """Compute logits for a batch of token sequences.

        Args:
            x: 2-D array of token ids, ``[batch_size, seq_len]``.

        Returns:
            Logits of shape ``[batch_size, seq_len, vocab_size]``.

        NOTE(review): the sub-modules are created in a fixed order inside this
        compact method; that order presumably determines their auto-generated
        parameter names, so reordering may break loading of saved parameters.
        """

        # x shape: [batch_size, seq_length]
        batch_size, seq_len = x.shape

        # Causal mask (1 = keep, 0 = block): lower-triangular, so position t only sees positions <= t.
        mask = jnp.tril(jnp.ones((seq_len, seq_len)))[None, None, :, :]  # [1, 1, seq_len, seq_len]
        mask = jnp.broadcast_to(mask, (batch_size, self.num_heads, seq_len, seq_len))

        # Embed the discrete tokens: each token id is mapped to a d_model-dimensional vector.
        x = nn.Embed(self.vocab_size, self.d_model)(x) #  [batch, seq_len, d_model]

        # Positional embedding, looked up for positions 0 .. seq_len - 1.
        positions = jnp.arange(seq_len)
        pos_embed = nn.Embed(self.seq_length, self.d_model)(positions)  # [seq_len, d_model]

        # Add the positional embedding to every sequence of the batch.
        x += pos_embed[None, :, :]

        # Transformer layers
        for _ in range(self.num_layers):
            x = TransformerBlock(self.d_model, self.num_heads)(x, mask=mask)

        # Layer norm
        x = nn.LayerNorm()(x)

        # Prediction head
        logits = nn.Dense(self.vocab_size)(x)
        return logits

# Rebuild the network with the same hyperparameters as the saved checkpoint.
# The first three fields are passed positionally in declaration order
# (d_model, num_heads, num_layers); the positional table is sized from the
# width of the training matrix.
model = AutoregressiveTransformerModel(
    D_MODEL,
    NUM_HEADS,
    NUM_LAYERS,
    seq_length=X_train.shape[1],
    vocab_size=VOCAB_SIZE,
)

In [10]:
print("Dimensions de X_test:", X_test.shape)

# Example prompt: the first len(SELECTED_COLUMNS_Xamin) + 1 tokens of the first test row.
initial_tokens = np.array(X_test[0, :len(SELECTED_COLUMNS_Xamin) + 1], dtype=np.int32)
print("Shape des tokens initiaux:", initial_tokens.shape)

# Length of a generated sequence: every Xamin column of every source, two
# tokens per cluster and per AGN, plus 2 start/end tokens.
# NOTE(review): the recorded outputs show 101 tokens per X_test row versus
# max_length = 99 — confirm that the two extra tokens are intentionally left out.
max_length = (
    MAX_SOURCES * len(SELECTED_COLUMNS_Xamin)
    + 2 * (MAX_CLUSTERS + MAX_AGN)
    + 2
)
print("max_length =",max_length)

Dimensions de X_test: (23868, 101)
Shape des tokens initiaux: (12,)
max_length = 99


In [11]:
@partial(jax.jit, static_argnames=("max_length", "echantillonage"))
def _generate_step(params, output_tokens, rng_key, i, temperature, max_length, echantillonage):
    """Fill position ``i`` of every sequence in ``output_tokens`` with a predicted token.

    Defined at module level with ``i`` as a traced (non-static) argument so the
    step is compiled once per (batch shape, max_length, sampling mode) and then
    reused for every position and every batch.

    Returns ``(updated_tokens, rng_key, step_done)`` where ``step_done`` is a
    per-sequence boolean: the token just written is SEP_TOKEN, or ``i`` is the
    last position.
    """
    # Replace every position >= i by PAD_TOKEN so that slots which have not been
    # generated yet are never fed to the model with stale content.
    visible = jnp.arange(max_length) < i
    masked_tokens = jnp.where(visible[None, :], output_tokens, PAD_TOKEN)

    # The logits at position i - 1 are the prediction for position i.
    logits = model.apply(params, masked_tokens)
    next_token_logits = logits[:, i - 1, :]

    if echantillonage:
        # Temperature sampling; the key is split so that each step draws fresh randomness.
        rng_key, subkey = jax.random.split(rng_key)
        next_tokens = jax.random.categorical(subkey, next_token_logits / temperature)
    else:
        # Greedy decoding.
        next_tokens = jnp.argmax(next_token_logits, axis=-1)

    output_tokens = output_tokens.at[:, i].set(next_tokens)
    step_done = (next_tokens == SEP_TOKEN) | (i >= max_length - 1)
    return output_tokens, rng_key, step_done


def generate_sequences_batch(params, initial_tokens_batch, max_length, echantillonage=False, temperature=1.0, key=None):
    """Autoregressively complete a batch of token prefixes up to ``max_length`` tokens.

    Args:
        params: Parameters passed to ``model.apply``.
        initial_tokens_batch: Prefix tokens, shape ``[batch, prefix_len]`` or
            ``[prefix_len]`` for a single sequence (list, NumPy or JAX array).
        max_length: Total length of the returned sequences.
        echantillonage: If True, sample from the softmax distribution; otherwise
            take the arg-max token (greedy decoding).
        temperature: Divides the logits before sampling; ignored when greedy.
        key: ``jax.random`` key used for sampling; defaults to ``PRNGKey(0)``.

    Returns:
        int32 array of shape ``[batch, max_length]``. Positions that were never
        generated (early stop) hold PAD_TOKEN.

    Raises:
        ValueError: if the prefix is empty, not 1-D/2-D, longer than
            ``max_length``, or if sampling is requested with ``temperature <= 0``.

    NOTE(review): generation only stops early when *every* sequence emits
    SEP_TOKEN at the *same* step; sequences that emitted SEP_TOKEN earlier keep
    being extended. This mirrors the previous behaviour — confirm it is intended.
    """
    initial_tokens_batch = jnp.asarray(initial_tokens_batch, dtype=jnp.int32)
    if initial_tokens_batch.ndim == 1:
        initial_tokens_batch = initial_tokens_batch[None, :]
    if initial_tokens_batch.ndim != 2:
        raise ValueError(
            f"initial_tokens_batch doit être de dimension 1 ou 2 (reçu : {initial_tokens_batch.ndim})"
        )

    max_length = int(max_length)
    echantillonage = bool(echantillonage)
    batch_size, seq_len = initial_tokens_batch.shape

    if seq_len < 1:
        # Position i is predicted from position i - 1, so at least one prefix token is required.
        raise ValueError("initial_tokens_batch doit contenir au moins un token par séquence")
    if seq_len > max_length:
        raise ValueError(f"le préfixe ({seq_len} tokens) dépasse max_length ({max_length})")
    if echantillonage and temperature <= 0:
        raise ValueError(f"temperature doit être > 0 pour l'échantillonnage (reçu : {temperature})")

    # Pre-allocate the output filled with padding, then copy the prefix in.
    output_tokens = jnp.full((batch_size, max_length), PAD_TOKEN, dtype=jnp.int32)
    output_tokens = output_tokens.at[:, :seq_len].set(initial_tokens_batch)

    if key is None:
        key = jax.random.PRNGKey(0)

    # Plain Python loop over positions; only the per-position step is compiled.
    for i in range(seq_len, max_length):
        output_tokens, key, step_done = _generate_step(
            params,
            output_tokens,
            key,
            i,
            temperature,
            max_length=max_length,
            echantillonage=echantillonage,
        )
        if bool(jnp.all(step_done)):
            break

    return output_tokens

def Pythie_optimized(X, min_idx_gen, max_idx_gen, save_list_of_generated_sequences, name, batch_size=16):
    """Generate completions for rows ``min_idx_gen..max_idx_gen`` of ``X`` and pickle them.

    Each selected row is truncated to its Xamin part (the first
    ``MAX_SOURCES * len(SELECTED_COLUMNS_Xamin) + 1`` tokens) and completed
    greedily with ``generate_sequences_batch``.

    Args:
        X: 2-D array of token sequences, one per row.
        min_idx_gen: First row index to generate (inclusive, >= 0).
        max_idx_gen: Last row index to generate (inclusive); clamped to the last row of ``X``.
        save_list_of_generated_sequences: When falsy the function does nothing,
            which lets a notebook cell be switched off without being deleted.
        name: Tag inserted in the output file name.
        batch_size: Number of sequences generated per call to the model.

    Returns:
        None. The list of generated sequences (one NumPy array per row) is
        written to ``generated_seq_by_imperator_{name}_{suffix}.pkl`` where
        ``suffix`` is ``"full"`` only when the whole of ``X`` was generated and
        ``"{min_idx_gen}-{max_idx_gen}"`` (after clamping) otherwise.

    Raises:
        ValueError: if ``batch_size < 1`` or ``min_idx_gen < 0``.
    """
    if not save_list_of_generated_sequences:
        return

    if batch_size < 1:
        raise ValueError(f"batch_size doit être >= 1 (reçu : {batch_size})")
    if min_idx_gen < 0:
        raise ValueError(f"min_idx_gen doit être >= 0 (reçu : {min_idx_gen})")

    index_end_of_Xamin_part = MAX_SOURCES * len(SELECTED_COLUMNS_Xamin) + 1
    last_row = len(X) - 1
    max_idx_gen = min(max_idx_gen, last_row)

    if min_idx_gen > max_idx_gen:
        # Nothing to generate (e.g. the requested range starts past the end of X):
        # do not write an empty, misleadingly named file.
        print(
            f"Aucune séquence à générer : intervalle vide [{min_idx_gen}, {max_idx_gen}] "
            f"(dernier indice disponible : {last_row})."
        )
        return

    list_of_generated_sequences = []

    # Contiguous slices are enough here; one tqdm tick per batch.
    batch_starts = range(min_idx_gen, max_idx_gen + 1, batch_size)
    for start in tqdm(batch_starts, desc="Génération par batch"):
        stop = min(start + batch_size, max_idx_gen + 1)
        batch = jnp.asarray(X[start:stop, :index_end_of_Xamin_part], dtype=jnp.int32)

        generated = generate_sequences_batch(
            params,
            batch,
            max_length=max_length,
            echantillonage=False
        )

        # Back to NumPy; extend() stores one array per generated sequence.
        list_of_generated_sequences.extend(np.array(generated))

    save_path = f"/content/drive/MyDrive/Stage_M2_2025/XV/results/{dossier}/"
    os.makedirs(save_path, exist_ok=True)

    # "full" is reserved for runs that really cover every row of X; a partial
    # run that merely ends on the last row keeps its explicit range.
    covers_all_rows = min_idx_gen == 0 and max_idx_gen == last_row
    suffix = "full" if covers_all_rows else f"{min_idx_gen}-{max_idx_gen}"
    with open(f"{save_path}generated_seq_by_imperator_{name}_{suffix}.pkl", 'wb') as f:
        pickle.dump(list_of_generated_sequences, f, protocol=pickle.HIGHEST_PROTOCOL)

In [12]:
# Test rows 0-5000 (author's estimate: ~25 min per 5000 sequences).
# The save flag is False, so this call returns immediately.
Pythie_optimized(
    X_test,
    min_idx_gen=0,
    max_idx_gen=5000,
    save_list_of_generated_sequences=False,
    name="test",
    batch_size=128,
)

Génération par batch: 100%|██████████| 40/40 [17:00<00:00, 25.52s/it]


In [20]:
# Test rows 5001-10000 (author's estimate: ~25 min per 5000 sequences).
# The save flag is True: sequences are generated and pickled.
Pythie_optimized(
    X_test,
    min_idx_gen=5001,
    max_idx_gen=10000,
    save_list_of_generated_sequences=True,
    name="test",
    batch_size=128,
)

Génération par batch: 100%|██████████| 40/40 [17:06<00:00, 25.65s/it]


In [14]:
# Test rows 10001-15000 (author's estimate: ~25 min per 5000 sequences).
# The save flag is False, so this call returns immediately.
Pythie_optimized(
    X_test,
    min_idx_gen=10001,
    max_idx_gen=15000,
    save_list_of_generated_sequences=False,
    name="test",
    batch_size=128,
)

In [15]:
# Test rows 15001-20000 (author's estimate: ~25 min per 5000 sequences).
# The save flag is False, so this call returns immediately.
Pythie_optimized(
    X_test,
    min_idx_gen=15001,
    max_idx_gen=20000,
    save_list_of_generated_sequences=False,
    name="test",
    batch_size=128,
)

In [16]:
# Test rows 20001-25000 (author's estimate: ~25 min per 5000 sequences).
# The upper bound is clamped to the last row of X_test inside the function.
# The save flag is False, so this call returns immediately.
Pythie_optimized(
    X_test,
    min_idx_gen=20001,
    max_idx_gen=25000,
    save_list_of_generated_sequences=False,
    name="test",
    batch_size=128,
)

In [17]:
# Rows 25001-35152 (author's estimate: ~25 min per 5000 sequences).
# NOTE(review): the recorded shape of X_test is (23868, 101), so this range
# starts past the end of the test set — confirm which dataset it was meant for.
# The save flag is False, so this call returns immediately.
Pythie_optimized(
    X_test,
    min_idx_gen=25001,
    max_idx_gen=35152,
    save_list_of_generated_sequences=False,
    name="test",
    batch_size=128,
)

In [18]:
# Whole test set in one go: the upper bound is clamped to the last row of X_test.
# The save flag is False, so this call returns immediately.
Pythie_optimized(
    X_test,
    min_idx_gen=0,
    max_idx_gen=100000,
    save_list_of_generated_sequences=False,
    name="test",
    batch_size=128,
)

In [19]:
# Reload a previously generated list of sequences for inspection.
# NOTE(review): the suffix in the file name must match what Pythie_optimized
# wrote ("full" or "{min_idx_gen}-{max_idx_gen}"); the recorded run of this
# cell failed with FileNotFoundError because no "0-50" file was present.
# NOTE(review): pickle.load can execute arbitrary code — only open files that
# you generated yourself.
with open(f'/content/drive/MyDrive/Stage_M2_2025/XV/results/{dossier}/generated_seq_by_imperator_test_0-50.pkl', 'rb') as f:
    gen_seq_on_test = pickle.load(f)

# Stack the list of per-sequence arrays into a single NumPy array.
arr = np.array(gen_seq_on_test)

FileNotFoundError: [Errno 2] No such file or directory: '/content/drive/MyDrive/Stage_M2_2025/XV/results/5050/generated_seq_by_imperator_test_0-50.pkl'

In [None]:

# Pick one generated sequence uniformly at random. randint's upper bound is
# exclusive, so len(arr) (not len(arr) - 1) is needed to be able to draw the
# last row, and to work when arr holds a single sequence.
num = np.random.randint(len(arr))

# Generated rows can be shorter than the reference rows (max_length vs. the
# width of X_test), so compare only the positions both arrays have.
# NOTE(review): row `num` of arr matches row `num` of X_test only when the
# loaded file was generated starting at index 0.
n_common = min(X_test.shape[1], arr.shape[1])

# Token-wise difference: 0 wherever the generated token equals the reference token.
print(X_test[num, :n_common] - arr[num, :n_common])

