In [None]:
# ============================================================
# 0. V√âRIFICATION GPU ET INSTALLATION
# ============================================================
import tensorflow as tf
print("TensorFlow version:", tf.__version__)
print("GPU disponibles:", tf.config.list_physical_devices('GPU'))

# Installation des d√©pendances
!pip install mlflow -q
print("\n‚úì Setup termin√©!")

In [None]:
# ============================================================
# 1. IMPORTS ET CONFIGURATION
# ============================================================
import numpy as np
import matplotlib.pyplot as plt
from tensorflow import keras
from tensorflow.keras import layers
import os
from datetime import datetime

# Make sure the directories used for saved weights/histories and for figures exist.
for directory in ('outputs', 'figures'):
    os.makedirs(directory, exist_ok=True)

print("‚úì Configuration termin√©e")

## 2. Architecture MA-TAP (Innovation)

In [None]:
# ============================================================
# 2. MA-TAP CELL (MAIN CONTRIBUTION)
# ============================================================

class MATAPCell(layers.Layer):
    """Memory-Augmented Time-Aware Path cell.

    Recurrent cell meant to be wrapped in ``layers.RNN``. Each step combines a
    GRU update with multi-head attention over a fixed-size FIFO memory of past
    (projected) inputs and blends the two with a learned sigmoid gate. The
    original description states the goal is to counter latent drift.

    States carried between steps:
        * hidden state, shape ``(batch, latent_dim)``
        * memory, flattened to ``(batch, memory_size * latent_dim)``

    Args:
        latent_dim: Width of the hidden state and of each memory slot.
        memory_size: Number of slots kept in the FIFO memory.
        num_heads: Number of attention heads; each head uses
            ``latent_dim // num_heads`` key/value dimensions.
        dropout_rate: Dropout rate passed to the attention layer.
        **kwargs: Forwarded to ``layers.Layer``.
    """

    def __init__(self, latent_dim, memory_size=10, num_heads=4, dropout_rate=0.1, **kwargs):
        super(MATAPCell, self).__init__(**kwargs)
        self.latent_dim = latent_dim
        self.memory_size = memory_size
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate

        # Sub-layers
        self.input_proj = layers.Dense(latent_dim, name="input_projection")
        self.gru = layers.GRUCell(latent_dim, name="gru_dynamics")
        self.attention = layers.MultiHeadAttention(
            num_heads=num_heads,
            key_dim=latent_dim // num_heads,
            value_dim=latent_dim // num_heads,
            dropout=dropout_rate,
            name="retrospective_attention"
        )
        self.layer_norm_attn = layers.LayerNormalization()
        self.layer_norm_out = layers.LayerNormalization()
        self.context_proj = layers.Dense(latent_dim, activation='tanh')
        self.gate_dense = layers.Dense(latent_dim, activation='sigmoid')
        self.memory_write_proj = layers.Dense(latent_dim)

        # RNN-cell contract: sizes of the two carried states and of the output.
        self.state_size = [latent_dim, memory_size * latent_dim]
        self.output_size = latent_dim

    def get_config(self):
        """Return the constructor arguments so the cell can be serialized.

        Without this override the arguments of ``__init__`` are not part of
        the layer config, and a model containing the cell cannot be rebuilt
        from its saved configuration.
        """
        config = super().get_config()
        config.update({
            'latent_dim': self.latent_dim,
            'memory_size': self.memory_size,
            'num_heads': self.num_heads,
            'dropout_rate': self.dropout_rate,
        })
        return config

    def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
        """Return all-zero ``[hidden, flat_memory]`` states.

        ``batch_size`` is read from ``inputs`` when not given; ``dtype``
        defaults to ``tf.float32``.
        """
        if batch_size is None:
            batch_size = tf.shape(inputs)[0]
        if dtype is None:
            dtype = tf.float32
        init_h = tf.zeros((batch_size, self.latent_dim), dtype=dtype)
        init_memory = tf.zeros((batch_size, self.memory_size * self.latent_dim), dtype=dtype)
        return [init_h, init_memory]

    def call(self, inputs, states, training=None):
        """Run one time step.

        Args:
            inputs: Input for the current step; its first axis is the batch.
            states: ``[h_prev, memory_flat]`` as produced by
                ``get_initial_state`` or by the previous step.
            training: Forwarded to the GRU cell and the attention layer.

        Returns:
            ``(h_corrected, [h_corrected, new_memory_flat])``.
        """
        h_prev, memory_flat = states
        batch_size = tf.shape(inputs)[0]

        z_t = self.input_proj(inputs)
        memory = tf.reshape(memory_flat, (batch_size, self.memory_size, self.latent_dim))

        # Local dynamics (GRU). Only the output is used; the state the GRU
        # returns is replaced by the gated/normalized state computed below.
        gru_out, _ = self.gru(z_t, [h_prev], training=training)

        # Retrospective attention: the GRU output queries the memory, which at
        # this point holds entries written by previous steps only.
        query = tf.expand_dims(gru_out, axis=1)
        context = self.attention(query=query, value=memory, key=memory, training=training)
        context = tf.squeeze(context, axis=1)
        context = self.layer_norm_attn(context)
        context_proj = self.context_proj(context)

        # Adaptive fusion: per-feature gate between GRU output and memory context.
        gate_input = tf.concat([gru_out, context_proj], axis=-1)
        alpha = self.gate_dense(gate_input)
        h_corrected = (1.0 - alpha) * gru_out + alpha * context_proj
        h_corrected = self.layer_norm_out(h_corrected)

        # FIFO memory update: drop the oldest slot, append a projection of the
        # current (projected) input.
        new_entry = self.memory_write_proj(z_t)
        new_entry = tf.expand_dims(new_entry, axis=1)
        new_memory = tf.concat([memory[:, 1:, :], new_entry], axis=1)
        new_memory_flat = tf.reshape(new_memory, (batch_size, self.memory_size * self.latent_dim))

        return h_corrected, [h_corrected, new_memory_flat]


class VanillaGRUCell(layers.Layer):
    """Baseline GRU cell (no memory, no attention) used for the ablation.

    Applies a dense input projection, a GRU cell and layer normalization.
    The sub-layers are created and built in ``build``.

    Args:
        latent_dim: Width of the hidden state and of the output.
        **kwargs: Forwarded to ``layers.Layer``.
    """

    def __init__(self, latent_dim, **kwargs):
        super(VanillaGRUCell, self).__init__(**kwargs)
        self.latent_dim = latent_dim
        # RNN-cell contract: a single carried state of width latent_dim.
        self.state_size = [latent_dim]
        self.output_size = latent_dim
        # Sub-layers are instantiated in build().
        self._input_proj = None
        self._gru = None
        self._layer_norm = None

    def build(self, input_shape):
        """Create the sub-layers and build them for the given input width."""
        input_dim = input_shape[-1]
        self._input_proj = layers.Dense(self.latent_dim, name="baseline_input_proj")
        self._input_proj.build((None, input_dim))
        self._gru = layers.GRUCell(self.latent_dim, name="baseline_gru")
        self._gru.build((None, self.latent_dim))
        self._layer_norm = layers.LayerNormalization(name="baseline_ln")
        self._layer_norm.build((None, self.latent_dim))
        super().build(input_shape)

    def get_config(self):
        """Return the constructor arguments so the cell can be serialized.

        Without this override ``latent_dim`` is not part of the layer config,
        and a model containing the cell cannot be rebuilt from its saved
        configuration.
        """
        config = super().get_config()
        config.update({'latent_dim': self.latent_dim})
        return config

    def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
        """Return a one-element list holding an all-zero hidden state.

        ``batch_size`` is read from ``inputs`` when not given; ``dtype``
        defaults to ``tf.float32``.
        """
        if batch_size is None:
            batch_size = tf.shape(inputs)[0]
        if dtype is None:
            dtype = tf.float32
        return [tf.zeros((batch_size, self.latent_dim), dtype=dtype)]

    def call(self, inputs, states, training=None):
        """Run one step and return ``(h_new, [h_new])``."""
        h_prev = states[0]
        z_t = self._input_proj(inputs)
        # Only the new state is needed; the GRU output is not used.
        _, [h_new] = self._gru(z_t, [h_prev], training=training)
        h_new = self._layer_norm(h_new)
        return h_new, [h_new]

print("‚úì Cellules MA-TAP et Baseline d√©finies")

In [None]:
# ============================================================
# 3. SPATIAL ENCODER/DECODER (IMPROVED)
# ============================================================

class SpatialEncoder(keras.Model):
    """Convolutional encoder: one frame -> latent vector.

    Four Conv2D + BatchNormalization stages with leaky-ReLU (slope 0.2)
    activations -- three with stride 2, one with stride 1 -- followed by a
    512-unit dense layer, dropout (0.3) and a linear projection to
    ``latent_dim``. The original docstring gives the frame shape as 64x64x1.
    """

    def __init__(self, latent_dim=128):
        super().__init__()
        # Strided convolutions: 64 -> 128 -> 256 filters.
        self.conv1 = layers.Conv2D(filters=64, kernel_size=4, strides=2, padding="same")
        self.bn1 = layers.BatchNormalization()
        self.conv2 = layers.Conv2D(filters=128, kernel_size=4, strides=2, padding="same")
        self.bn2 = layers.BatchNormalization()
        self.conv3 = layers.Conv2D(filters=256, kernel_size=4, strides=2, padding="same")
        self.bn3 = layers.BatchNormalization()
        # Extra stride-1 convolution keeping the spatial resolution.
        self.conv4 = layers.Conv2D(filters=256, kernel_size=3, strides=1, padding="same")
        self.bn4 = layers.BatchNormalization()
        # Dense head.
        self.flatten = layers.Flatten()
        self.fc1 = layers.Dense(units=512, activation="relu")
        self.dropout = layers.Dropout(rate=0.3)
        self.fc_out = layers.Dense(units=latent_dim)

    def call(self, x, training=False):
        """Encode a batch of frames; ``training`` toggles batch-norm and dropout."""
        conv_stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
            (self.conv4, self.bn4),
        )
        features = x
        for conv, norm in conv_stages:
            features = tf.nn.leaky_relu(norm(conv(features), training=training), 0.2)

        hidden = self.fc1(self.flatten(features))
        hidden = self.dropout(hidden, training=training)
        return self.fc_out(hidden)


class SpatialDecoder(keras.Model):
    """Convolutional decoder: latent vector -> single-channel frame.

    Two dense layers expand the latent vector to an 8x8x256 feature map, three
    stride-2 transposed convolutions (256, 128, 64 filters) upsample it, a
    32-filter refinement convolution follows, and a sigmoid convolution
    produces the one-channel output. ``latent_dim`` is accepted for symmetry
    with the encoder but is not used when creating the layers.
    """

    def __init__(self, latent_dim=128):
        super().__init__()
        # Dense stem and reshape to a spatial feature map.
        self.fc1 = layers.Dense(units=512, activation="relu")
        self.fc2 = layers.Dense(units=8 * 8 * 256, activation="relu")
        self.reshape = layers.Reshape(target_shape=(8, 8, 256))
        # Upsampling stack.
        self.deconv1 = layers.Conv2DTranspose(filters=256, kernel_size=4, strides=2, padding="same")
        self.bn1 = layers.BatchNormalization()
        self.deconv2 = layers.Conv2DTranspose(filters=128, kernel_size=4, strides=2, padding="same")
        self.bn2 = layers.BatchNormalization()
        self.deconv3 = layers.Conv2DTranspose(filters=64, kernel_size=4, strides=2, padding="same")
        self.bn3 = layers.BatchNormalization()
        # Final refinement layer and sigmoid output.
        self.refine = layers.Conv2D(filters=32, kernel_size=3, padding="same")
        self.bn4 = layers.BatchNormalization()
        self.output_conv = layers.Conv2D(filters=1, kernel_size=3, padding="same", activation="sigmoid")

    def call(self, z, training=False):
        """Decode a batch of latent vectors; ``training`` toggles batch-norm."""
        feature_map = self.reshape(self.fc2(self.fc1(z)))

        conv_stages = (
            (self.deconv1, self.bn1),
            (self.deconv2, self.bn2),
            (self.deconv3, self.bn3),
            (self.refine, self.bn4),
        )
        for conv, norm in conv_stages:
            feature_map = tf.nn.relu(norm(conv(feature_map), training=training))

        return self.output_conv(feature_map)

print("‚úì Encodeur/D√©codeur AM√âLIOR√âS d√©finis")
print("  - Encoder: 64->128->256->256 filtres, Leaky ReLU")
print("  - Decoder: 256->128->64->32 filtres + refinement layer")

In [None]:
# ============================================================
# 4. FULL MODELS (MA-TAP + BASELINE)
# ============================================================

class MATAPModel(keras.Model):
    """Full MA-TAP model: frame encoder -> MA-TAP recurrent core -> decoder.

    Frames are reshaped to ``(batch * time, 64, 64, 1)`` for the encoder and
    the decoder output is reshaped back to ``(batch, time, 64, 64, 1)``.
    """

    def __init__(self, latent_dim=64, memory_size=10, num_heads=4, dropout_rate=0.1):
        super().__init__()
        self.latent_dim = latent_dim
        self.memory_size = memory_size

        self.encoder = SpatialEncoder(latent_dim)
        self.decoder = SpatialDecoder(latent_dim)
        self.matap_cell = MATAPCell(latent_dim, memory_size, num_heads, dropout_rate)
        self.rnn = layers.RNN(self.matap_cell, return_sequences=True, return_state=True)
        # Small MLP mapping each hidden state to a latent vector.
        self.predictor = keras.Sequential([
            layers.Dense(latent_dim * 2, activation='relu'),
            layers.Dropout(dropout_rate),
            layers.Dense(latent_dim)
        ])

    def encode_sequence(self, frames, training=False):
        """Encode every frame of every sequence; returns ``(batch, time, latent_dim)``."""
        dims = tf.shape(frames)
        num_sequences, num_steps = dims[0], dims[1]
        per_frame = tf.reshape(frames, (num_sequences * num_steps, 64, 64, 1))
        latents = self.encoder(per_frame, training=training)
        return tf.reshape(latents, (num_sequences, num_steps, self.latent_dim))

    def decode_sequence(self, z_seq, training=False):
        """Decode every latent vector; returns ``(batch, time, 64, 64, 1)``."""
        dims = tf.shape(z_seq)
        num_sequences, num_steps = dims[0], dims[1]
        per_step = tf.reshape(z_seq, (num_sequences * num_steps, self.latent_dim))
        frames = self.decoder(per_step, training=training)
        return tf.reshape(frames, (num_sequences, num_steps, 64, 64, 1))

    def call(self, inputs, training=False):
        """Run the full pipeline on a batch of frame sequences.

        Returns:
            ``(reconstructed, z_seq, z_pred, [final_h, final_memory])`` where
            ``z_seq`` are the encoder latents, ``z_pred`` the predictor
            outputs and ``reconstructed`` the decoded ``z_pred``.
        """
        num_sequences = tf.shape(inputs)[0]
        z_seq = self.encode_sequence(inputs, training=training)

        # Start from all-zero hidden state and memory.
        zero_states = self.matap_cell.get_initial_state(batch_size=num_sequences)
        h_seq, final_h, final_memory = self.rnn(
            z_seq, initial_state=zero_states, training=training
        )

        z_pred = self.predictor(h_seq)
        reconstructed = self.decode_sequence(z_pred, training=training)
        return reconstructed, z_seq, z_pred, [final_h, final_memory]


class BaselineTAPModel(keras.Model):
    """Baseline model for the ablation: same encoder/decoder, plain GRU core.

    Frames are reshaped to ``(batch * time, 64, 64, 1)`` for the encoder and
    the decoder output is reshaped back to ``(batch, time, 64, 64, 1)``.
    """

    def __init__(self, latent_dim=64, dropout_rate=0.1):
        super().__init__()
        self.latent_dim = latent_dim

        self.encoder = SpatialEncoder(latent_dim)
        self.decoder = SpatialDecoder(latent_dim)
        self.gru_cell = VanillaGRUCell(latent_dim)
        self.rnn = layers.RNN(self.gru_cell, return_sequences=True, return_state=True)
        # Small MLP mapping each hidden state to a latent vector.
        self.predictor = keras.Sequential([
            layers.Dense(latent_dim * 2, activation='relu'),
            layers.Dropout(dropout_rate),
            layers.Dense(latent_dim)
        ])

    def encode_sequence(self, frames, training=False):
        """Encode every frame of every sequence; returns ``(batch, time, latent_dim)``."""
        dims = tf.shape(frames)
        num_sequences, num_steps = dims[0], dims[1]
        per_frame = tf.reshape(frames, (num_sequences * num_steps, 64, 64, 1))
        latents = self.encoder(per_frame, training=training)
        return tf.reshape(latents, (num_sequences, num_steps, self.latent_dim))

    def decode_sequence(self, z_seq, training=False):
        """Decode every latent vector; returns ``(batch, time, 64, 64, 1)``."""
        dims = tf.shape(z_seq)
        num_sequences, num_steps = dims[0], dims[1]
        per_step = tf.reshape(z_seq, (num_sequences * num_steps, self.latent_dim))
        frames = self.decoder(per_step, training=training)
        return tf.reshape(frames, (num_sequences, num_steps, 64, 64, 1))

    def call(self, inputs, training=False):
        """Run the full pipeline on a batch of frame sequences.

        Returns:
            ``(reconstructed, z_seq, z_pred, [final_h])`` where ``z_seq`` are
            the encoder latents, ``z_pred`` the predictor outputs and
            ``reconstructed`` the decoded ``z_pred``.
        """
        num_sequences = tf.shape(inputs)[0]
        z_seq = self.encode_sequence(inputs, training=training)

        # Start from an all-zero hidden state.
        zero_states = self.gru_cell.get_initial_state(batch_size=num_sequences)
        h_seq, final_h = self.rnn(z_seq, initial_state=zero_states, training=training)

        z_pred = self.predictor(h_seq)
        reconstructed = self.decode_sequence(z_pred, training=training)
        return reconstructed, z_seq, z_pred, [final_h]

print("‚úì Mod√®les MA-TAP et Baseline d√©finis")

In [None]:
# ============================================================
# 5. MOVING MNIST GENERATOR
# ============================================================

class MovingMNISTGenerator:
    """Generate Moving MNIST sequences: digits bouncing inside a square frame.

    Args:
        image_size: Side length of the square output frames.
        digit_size: Side length of the region each digit occupies. Digits are
            cropped to this size when it is smaller than the MNIST images.
        num_digits: Number of digits drawn into every sequence.
        seq_length: Number of frames per sequence.
        seed: Optional seed. When given, the generator draws from its own
            ``np.random.RandomState(seed)`` and is reproducible on its own.
            When ``None`` (default) the global NumPy RNG is used, exactly as
            before.

    Raises:
        ValueError: If ``digit_size`` is larger than ``image_size``.
    """

    def __init__(self, image_size=64, digit_size=28, num_digits=2, seq_length=20, seed=None):
        if digit_size > image_size:
            raise ValueError(
                f"digit_size ({digit_size}) must not exceed image_size ({image_size})"
            )
        self.image_size = image_size
        self.digit_size = digit_size
        self.num_digits = num_digits
        self.seq_length = seq_length
        # RandomState exposes the same randint/uniform API as the np.random
        # module, so both can be used interchangeably below.
        self._rng = np.random.RandomState(seed) if seed is not None else np.random

        # Load MNIST and scale pixel values to [0, 1].
        (x_train, _), (x_test, _) = keras.datasets.mnist.load_data()
        self.mnist_train = x_train.astype(np.float32) / 255.0
        self.mnist_test = x_test.astype(np.float32) / 255.0
        print(f"[MovingMNIST] Loaded {len(self.mnist_train)} train, {len(self.mnist_test)} test digits")

    def _get_random_digit(self, use_test=False):
        """Return one uniformly drawn digit image from the train or test split."""
        data = self.mnist_test if use_test else self.mnist_train
        return data[self._rng.randint(len(data))]

    def _generate_trajectory(self, seq_length):
        """Return ``seq_length`` integer ``(x, y)`` top-left positions.

        The digit starts at a random position with a random direction and a
        speed in [2, 5) pixels per frame; when it leaves the valid range on an
        axis, its velocity on that axis is flipped and the position clipped.
        """
        max_pos = self.image_size - self.digit_size
        # randint's upper bound is exclusive: +1 makes max_pos itself a valid
        # start and keeps the call valid when the digit fills the whole frame.
        x = self._rng.randint(0, max_pos + 1)
        y = self._rng.randint(0, max_pos + 1)
        speed = self._rng.uniform(2, 5)
        angle = self._rng.uniform(0, 2 * np.pi)
        vx, vy = speed * np.cos(angle), speed * np.sin(angle)

        positions = []
        for _ in range(seq_length):
            positions.append((int(x), int(y)))
            x, y = x + vx, y + vy
            if x < 0 or x > max_pos:
                vx = -vx
                x = np.clip(x, 0, max_pos)
            if y < 0 or y > max_pos:
                vy = -vy
                y = np.clip(y, 0, max_pos)
        return positions

    def generate_sequence(self, use_test=False):
        """Return one float32 sequence of shape ``(seq_length, image_size, image_size, 1)``.

        Overlapping digits are summed and clipped to [0, 1].
        """
        seq = np.zeros((self.seq_length, self.image_size, self.image_size, 1), dtype=np.float32)
        for _ in range(self.num_digits):
            digit = self._get_random_digit(use_test)
            # Use the actual patch shape so a digit_size larger than the MNIST
            # image does not cause a shape mismatch when pasting.
            patch = digit[:self.digit_size, :self.digit_size]
            height, width = patch.shape
            for t, (x, y) in enumerate(self._generate_trajectory(self.seq_length)):
                rows, cols = slice(y, y + height), slice(x, x + width)
                seq[t, rows, cols, 0] = np.clip(seq[t, rows, cols, 0] + patch, 0, 1)
        return seq

    def generate_batch(self, batch_size, use_test=False):
        """Return ``batch_size`` sequences stacked along a new first axis."""
        batch = np.zeros((batch_size, self.seq_length, self.image_size, self.image_size, 1), dtype=np.float32)
        for i in range(batch_size):
            batch[i] = self.generate_sequence(use_test)
        return batch

# Instantiate the shared generator used by the training and evaluation cells
# (the constructor loads MNIST).
data_gen = MovingMNISTGenerator(seq_length=20, num_digits=2)
print("‚úì G√©n√©rateur Moving MNIST pr√™t")

## 6. Entraînement (Version AMÉLIORÉE)

### Changements par rapport à la version précédente :

| Paramètre | Avant | Après | Raison |
|-----------|-------|-------|--------|
| `latent_dim` | 64 | **128** | Plus de capacité pour représenter les digits |
| `epochs` | 50 | **150** | Convergence complète |
| `learning_rate` | 1e-3 | **5e-4** | Gradients plus stables |
| `batch_size` | 32 | **16** | Meilleure généralisation |
| `latent_loss_weight` | 0.1 | **0.01** | Focus sur reconstruction |
| `encoder filters` | 32-64-128 | **64-128-256-256** | Plus de capacité |
| `LR schedule` | Constant | **Cosine Decay** | Meilleure convergence |
| `Loss` | BCE | **MSE + 0.5*BCE** | Meilleurs gradients pour pixels |

⚠️ **Temps estimé sur T4 GPU : ~45-60 minutes** (vs ~25 min avant)

In [None]:
# ============================================================
# 6. TRAINING CONFIGURATION (IMPROVED)
# ============================================================

CONFIG = dict(
    # Architecture
    latent_dim=128,             # raised from 64 for more capacity
    memory_size=10,
    num_heads=4,
    seq_length=20,

    # Training
    batch_size=16,              # lowered from 32
    epochs=150,                 # raised from 50
    learning_rate=5e-4,         # lowered from 1e-3

    # Data
    num_train_samples=3000,
    num_val_samples=500,

    # Loss weighting
    latent_loss_weight=0.01,    # lowered from 0.1 to emphasise reconstruction
)

banner = "=" * 60
print(banner)
print("CONFIGURATION AM√âLIOR√âE POUR RECONSTRUCTION")
print(banner)
# One "  key: value" line per setting, in declaration order.
print("\n".join(f"  {key}: {value}" for key, value in CONFIG.items()))

In [None]:
# ============================================================
# 7. TRAINING FUNCTIONS (IMPROVED)
# ============================================================

def create_train_step(model, optimizer, latent_weight=0.01):
    """Build a ``tf.function`` training step bound to ``model`` and ``optimizer``.

    The returned ``train_step(batch)`` performs one optimizer update with
    gradients clipped to a global norm of 1.0 and returns
    ``(total_loss, mean_ssim, mse)``.

    The objective is ``MSE + 0.5 * BCE`` between the input frames and the
    model output, plus ``latent_weight`` times the squared error between the
    encoder latents at steps 1..T-1 and the predicted latents at steps 0..T-2.

    NOTE(review): the pixel terms compare output step t with input step t,
    whereas the latent term compares the prediction at step t with the encoder
    latent at step t+1; both are derived from the same predicted latents.
    Confirm this alignment is intended.
    """
    @tf.function
    def train_step(batch):
        with tf.GradientTape() as tape:
            reconstructed, z_true, z_pred, _ = model(batch, training=True)

            # Pixel-space terms: MSE plus a half-weighted binary cross-entropy.
            mse_term = tf.reduce_mean(tf.square(batch - reconstructed))
            bce_term = tf.reduce_mean(keras.losses.binary_crossentropy(batch, reconstructed))
            reconstruction_term = mse_term + 0.5 * bce_term

            # Latent-space term (down-weighted by latent_weight).
            latent_term = tf.reduce_mean(tf.square(z_true[:, 1:] - z_pred[:, :-1]))

            objective = reconstruction_term + latent_weight * latent_term

        variables = model.trainable_variables
        clipped, _ = tf.clip_by_global_norm(tape.gradient(objective, variables), 1.0)
        optimizer.apply_gradients(zip(clipped, variables))

        # SSIM is computed per frame, so time is folded into the batch axis.
        dims = tf.shape(batch)
        flat_shape = (dims[0] * dims[1], 64, 64, 1)
        per_frame_ssim = tf.image.ssim(
            tf.reshape(batch, flat_shape),
            tf.reshape(reconstructed, flat_shape),
            max_val=1.0,
        )

        return objective, tf.reduce_mean(per_frame_ssim), mse_term
    return train_step

def create_val_step(model, latent_weight=0.01):
    """Build a ``tf.function`` validation step bound to ``model``.

    The returned ``val_step(batch)`` runs the model with ``training=False``
    and returns ``(total_loss, mean_ssim)``. The loss is
    ``MSE + 0.5 * BCE`` on the frames plus ``latent_weight`` times the squared
    error between encoder latents at steps 1..T-1 and predicted latents at
    steps 0..T-2.
    """
    @tf.function
    def val_step(batch):
        reconstructed, z_true, z_pred, _ = model(batch, training=False)

        # Pixel-space terms.
        mse_term = tf.reduce_mean(tf.square(batch - reconstructed))
        bce_term = tf.reduce_mean(keras.losses.binary_crossentropy(batch, reconstructed))
        # Latent-space term.
        latent_term = tf.reduce_mean(tf.square(z_true[:, 1:] - z_pred[:, :-1]))
        objective = (mse_term + 0.5 * bce_term) + latent_weight * latent_term

        # SSIM is computed per frame, so time is folded into the batch axis.
        dims = tf.shape(batch)
        flat_shape = (dims[0] * dims[1], 64, 64, 1)
        per_frame_ssim = tf.image.ssim(
            tf.reshape(batch, flat_shape),
            tf.reshape(reconstructed, flat_shape),
            max_val=1.0,
        )

        return objective, tf.reduce_mean(per_frame_ssim)
    return val_step

def train_model(model, model_name, epochs=150, train_data=None, val_data=None, patience=None):
    """Train ``model`` with Adam and a cosine-decayed learning rate.

    Hyper-parameters are read from the module-level ``CONFIG``. After every
    epoch the weights are saved to ``outputs/{model_name}_best.weights.h5``
    when the validation loss improves; the last weights are saved to
    ``outputs/{model_name}_final.weights.h5`` at the end.

    Args:
        model: Model returning ``(reconstructed, z_true, z_pred, states)``.
        model_name: Prefix used in log messages and weight file names.
        epochs: Maximum number of epochs.
        train_data: Optional array of training sequences. When ``None`` it is
            generated with the module-level ``data_gen``. Pass the same array
            to several calls to train different models on identical data.
        val_data: Optional array of validation sequences, generated with
            ``data_gen`` (test digits) when ``None``.
        patience: Optional early-stopping patience, in epochs without
            validation improvement. ``None`` (default) disables early
            stopping.

    Returns:
        Dict with per-epoch lists under ``train_loss``, ``val_loss``,
        ``train_ssim`` and ``val_ssim``.

    Raises:
        ValueError: If either dataset holds fewer samples than one batch.
    """
    print(f"\n{'='*60}")
    print(f"Entra√Ænement {model_name} (AM√âLIOR√â)")
    print(f"{'='*60}")

    batch_size = CONFIG['batch_size']

    # Generate whatever data the caller did not supply.
    if train_data is None or val_data is None:
        print("G√©n√©ration des donn√©es...")
        if train_data is None:
            train_data = data_gen.generate_batch(CONFIG['num_train_samples'])
        if val_data is None:
            val_data = data_gen.generate_batch(CONFIG['num_val_samples'], use_test=True)

    num_train = len(train_data)
    # Incomplete trailing batches are dropped so every step sees the same shape.
    num_batches = num_train // batch_size
    num_val_batches = len(val_data) // batch_size
    if num_batches == 0 or num_val_batches == 0:
        raise ValueError(
            f"Need at least one full batch of {batch_size} samples in both datasets "
            f"(got {num_train} train, {len(val_data)} val)"
        )

    # Optimizer with cosine learning-rate decay down to 10% of the initial value.
    lr_schedule = keras.optimizers.schedules.CosineDecay(
        initial_learning_rate=CONFIG['learning_rate'],
        decay_steps=max(1, epochs * num_batches),
        alpha=0.1
    )
    optimizer = keras.optimizers.Adam(learning_rate=lr_schedule)

    # Model-specific compiled steps.
    train_step = create_train_step(model, optimizer, CONFIG['latent_loss_weight'])
    val_step = create_val_step(model, CONFIG['latent_loss_weight'])

    history = {'train_loss': [], 'val_loss': [], 'train_ssim': [], 'val_ssim': []}
    best_val_loss = float('inf')
    patience_counter = 0
    # Defined up front so the final report also works when no epoch runs.
    val_ssim = float('nan')

    print(f"D√©marrage ({epochs} epochs, {num_batches} batches/epoch, LR cosine decay)...\n")

    for epoch in range(epochs):
        # Shuffle indices only; batches are gathered one at a time so the
        # whole training array is never copied.
        order = np.random.permutation(num_train)

        # Training
        train_losses, train_ssims = [], []
        for i in range(num_batches):
            batch_idx = order[i * batch_size:(i + 1) * batch_size]
            loss, ssim, _ = train_step(train_data[batch_idx])
            train_losses.append(loss.numpy())
            train_ssims.append(ssim.numpy())

        # Validation
        val_losses, val_ssims = [], []
        for i in range(num_val_batches):
            batch = val_data[i * batch_size:(i + 1) * batch_size]
            loss, ssim = val_step(batch)
            val_losses.append(loss.numpy())
            val_ssims.append(ssim.numpy())

        # Epoch averages
        train_loss = np.mean(train_losses)
        val_loss = np.mean(val_losses)
        train_ssim = np.mean(train_ssims)
        val_ssim = np.mean(val_ssims)

        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['train_ssim'].append(train_ssim)
        history['val_ssim'].append(val_ssim)

        # Log the first five epochs, every tenth epoch and the last one.
        if epoch % 10 == 0 or epoch == epochs - 1 or epoch < 5:
            print(f"Epoch {epoch+1:3d}/{epochs} | "
                  f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | "
                  f"Train SSIM: {train_ssim:.4f} | Val SSIM: {val_ssim:.4f}")

        # Keep the best weights seen so far.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            model.save_weights(f'outputs/{model_name}_best.weights.h5')
            patience_counter = 0
        else:
            patience_counter += 1

        # Optional early stopping (off unless `patience` is given).
        if patience is not None and patience_counter >= patience:
            print(f"\n‚ö†Ô∏è Early stopping at epoch {epoch+1}")
            break

    # Final checkpoint
    model.save_weights(f'outputs/{model_name}_final.weights.h5')
    print(f"\n‚úì {model_name} termin√©! Best Val Loss: {best_val_loss:.4f}")
    print(f"  Final Val SSIM: {val_ssim:.4f}")

    return history

print("‚úì Fonctions d'entra√Ænement AM√âLIOR√âES d√©finies")
print("  - Loss: MSE + 0.5*BCE (meilleur pour pixels)")
print("  - Learning Rate: Cosine Decay")
print("  - Latent loss weight r√©duit: 0.01")

In [None]:
# ============================================================
# 8. MA-TAP TRAINING (latent_dim=128)
# ============================================================

# Seed Python, NumPy and TensorFlow so weight initialisation is reproducible.
SEED = CONFIG.get('seed', 42)
keras.utils.set_random_seed(SEED)

# Create the MA-TAP model with the configured capacity.
matap_model = MATAPModel(
    latent_dim=CONFIG['latent_dim'],
    memory_size=CONFIG['memory_size'],
    num_heads=CONFIG['num_heads']
)

# Build the variables with a dummy forward pass.
dummy = tf.zeros((1, CONFIG['seq_length'], 64, 64, 1))
_ = matap_model(dummy)
num_params = sum([tf.reduce_prod(v.shape).numpy() for v in matap_model.trainable_variables])
print(f"MA-TAP param√®tres: {num_params:,}")
print(f"  (latent_dim={CONFIG['latent_dim']}, memory_size={CONFIG['memory_size']})")

# Re-seed NumPy right before training: in this notebook the data generation
# and the epoch shuffling use the global NumPy RNG, so models trained after
# the same seed see identical data in identical order.
np.random.seed(SEED)
history_matap = train_model(matap_model, 'matap', epochs=CONFIG['epochs'])

In [None]:
# ============================================================
# 9. BASELINE TRAINING (latent_dim=128)
# ============================================================

# Seed Python, NumPy and TensorFlow so weight initialisation is reproducible.
SEED = CONFIG.get('seed', 42)
keras.utils.set_random_seed(SEED)

# Create the baseline model once. The previous version first built a model,
# applied one "warmup" gradient step to it and then discarded it in favour of
# a freshly created model, so that warmup never affected the trained model.
baseline_model = BaselineTAPModel(latent_dim=CONFIG['latent_dim'])

# Build the variables with a dummy forward pass.
# NOTE(review): this build runs with training=True (as before), so batch-norm
# moving statistics receive one update from the all-zero batch, unlike the
# MA-TAP build which uses the default training=False -- confirm if intended.
dummy_baseline = tf.zeros((2, CONFIG['seq_length'], 64, 64, 1))
_ = baseline_model(dummy_baseline, training=True)

num_params = sum([tf.reduce_prod(v.shape).numpy() for v in baseline_model.trainable_variables])
print(f"Baseline param√®tres: {num_params:,}")
print(f"  (latent_dim={CONFIG['latent_dim']})")

# Re-seed NumPy right before training: in this notebook the data generation
# and the epoch shuffling use the global NumPy RNG, so models trained after
# the same seed see identical data in identical order.
np.random.seed(SEED)
history_baseline = train_model(baseline_model, 'baseline', epochs=CONFIG['epochs'])

## 10. Visualisations et Résultats

In [None]:
# ============================================================
# 10. FIGURE GENERATION
# ============================================================

# Figure 1: learning curves (loss on the left, SSIM on the right)
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# (history key suffix, y-axis label, panel title)
panels = (
    ('loss', 'Loss', 'Training & Validation Loss'),
    ('ssim', 'SSIM', 'Structural Similarity Index'),
)
# (history dict, line colour, legend prefix)
curves = (
    (history_matap, 'b', 'MA-TAP'),
    (history_baseline, 'r', 'Baseline'),
)

for ax, (metric, ylabel, title) in zip(axes, panels):
    for history, color, model_label in curves:
        # Solid line for training, dashed line for validation.
        ax.plot(history[f'train_{metric}'], f'{color}-', label=f'{model_label} Train', linewidth=2)
        ax.plot(history[f'val_{metric}'], f'{color}--', label=f'{model_label} Val', linewidth=2)
    ax.set_xlabel('Epoch', fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.set_title(title, fontsize=14)
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('figures/training_curves.png', dpi=150, bbox_inches='tight')
plt.show()
print("‚úì Figure sauvegard√©e: figures/training_curves.png")

In [None]:
# ============================================================
# 11. SSIM EVALUATION PER TIMESTEP
# ============================================================

def compute_ssim_per_timestep(model, test_data, batch_size=32):
    """Return the mean SSIM between input and model output for every timestep.

    The model is run with ``training=False`` on chunks of ``batch_size``
    sequences, so memory use does not grow with the size of ``test_data``.

    Args:
        model: Model whose first output has the same shape as its input.
        test_data: Array of sequences; axis 0 is the sequence index and
            axis 1 the timestep, the remaining axes describe one frame.
        batch_size: Number of sequences per forward pass. ``None`` or a
            non-positive value processes everything in a single pass.

    Returns:
        List of ``test_data.shape[1]`` floats, one mean SSIM per timestep
        (empty list when ``test_data`` holds no sequences).
    """
    num_sequences, num_steps = test_data.shape[0], test_data.shape[1]
    if num_sequences == 0:
        return []
    if batch_size is None or batch_size <= 0:
        batch_size = num_sequences

    frame_shape = tuple(test_data.shape[2:])
    ssim_sum = np.zeros(num_steps, dtype=np.float64)

    for start in range(0, num_sequences, batch_size):
        chunk = test_data[start:start + batch_size]
        reconstructed, _, _, _ = model(chunk, training=False)

        # SSIM works per image: fold time into the batch axis, then unfold.
        flat_true = tf.reshape(chunk, (-1,) + frame_shape)
        flat_pred = tf.reshape(reconstructed, (-1,) + frame_shape)
        per_frame = tf.image.ssim(flat_true, flat_pred, max_val=1.0)
        per_frame = tf.reshape(per_frame, (chunk.shape[0], num_steps))

        ssim_sum += tf.reduce_sum(per_frame, axis=0).numpy()

    return (ssim_sum / num_sequences).tolist()

# Generate evaluation sequences built from the MNIST test digits.
test_data = data_gen.generate_batch(100, use_test=True)

# Per-timestep SSIM for both models (MA-TAP first, then the baseline).
ssim_matap, ssim_baseline = (
    compute_ssim_per_timestep(net, test_data) for net in (matap_model, baseline_model)
)

# Figure 2: SSIM over time
plt.figure(figsize=(10, 6))
timesteps = range(len(ssim_matap))
for values, fmt, label in (
    (ssim_matap, 'b-o', 'MA-TAP (Ours)'),
    (ssim_baseline, 'r--s', 'Baseline (GRU)'),
):
    plt.plot(timesteps, values, fmt, linewidth=2, markersize=6, label=label)
plt.xlabel('Timestep', fontsize=12)
plt.ylabel('SSIM', fontsize=12)
plt.title('Temporal Coherence: SSIM over Time', fontsize=14)
plt.legend(fontsize=11)
plt.grid(True, alpha=0.3)
plt.ylim([0, 1])
plt.savefig('figures/ssim_comparison.png', dpi=150, bbox_inches='tight')
plt.show()
print("‚úì Figure sauvegard√©e: figures/ssim_comparison.png")

In [None]:
# ============================================================
# 12. RECONSTRUCTION VISUALISATION
# ============================================================

# Take one test sequence, keeping the batch axis: (1, T, 64, 64, 1)
sample = test_data[0:1]
rec_matap, _, _, _ = matap_model(sample, training=False)
rec_baseline, _, _, _ = baseline_model(sample, training=False)

# One row per source, at most the first 10 timesteps as columns.
T = sample.shape[1]
num_cols = min(T, 10)
rows = (
    ('Ground Truth', sample),
    ('MA-TAP', rec_matap.numpy()),
    ('Baseline', rec_baseline.numpy()),
)

# squeeze=False keeps `axes` two-dimensional even with a single column.
fig, axes = plt.subplots(len(rows), num_cols, figsize=(20, 6), squeeze=False)

for r, (row_label, frames) in enumerate(rows):
    for t in range(num_cols):
        ax = axes[r, t]
        ax.imshow(frames[0, t, :, :, 0], cmap='gray', vmin=0, vmax=1)
        # Hide ticks and frame instead of calling axis('off'): turning the
        # axis off also suppresses the y-label used as the row caption.
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)
        if r == 0:
            ax.set_title(f't={t}', fontsize=9)
    axes[r, 0].set_ylabel(row_label, fontsize=11)

plt.suptitle('Reconstruction Comparison', fontsize=14)
plt.tight_layout()
plt.savefig('figures/reconstruction_comparison.png', dpi=150, bbox_inches='tight')
plt.show()
print("‚úì Figure sauvegard√©e: figures/reconstruction_comparison.png")

In [None]:
# ============================================================
# 13. RESULTS SUMMARY
# ============================================================

rule = "=" * 60
print(rule)
print("R√âSULTATS FINAUX")
print(rule)

# Same report for both models: last-epoch metrics plus the best validation loss.
for label, history in (("MA-TAP", history_matap), ("Baseline", history_baseline)):
    print(f"\nüìä {label}:")
    print(f"   Final Train Loss: {history['train_loss'][-1]:.4f}")
    print(f"   Final Val Loss:   {history['val_loss'][-1]:.4f}")
    print(f"   Final Train SSIM: {history['train_ssim'][-1]:.4f}")
    print(f"   Final Val SSIM:   {history['val_ssim'][-1]:.4f}")
    print(f"   Best Val Loss:    {min(history['val_loss']):.4f}")

# Relative improvement of the last-epoch validation SSIM, in percent.
final_matap_ssim = history_matap['val_ssim'][-1]
final_baseline_ssim = history_baseline['val_ssim'][-1]
ssim_improvement = (final_matap_ssim - final_baseline_ssim) / final_baseline_ssim * 100
print(f"\nüéØ Am√©lioration SSIM MA-TAP vs Baseline: {ssim_improvement:+.2f}%")

print("\n" + "=" * 60)

In [None]:
# ============================================================
# 14. DOWNLOADING THE RESULTS
# ============================================================

# Bundle all results into zip archives
import shutil

# Save the training histories
np.savez('outputs/training_history.npz',
         matap_train_loss=history_matap['train_loss'],
         matap_val_loss=history_matap['val_loss'],
         matap_train_ssim=history_matap['train_ssim'],
         matap_val_ssim=history_matap['val_ssim'],
         baseline_train_loss=history_baseline['train_loss'],
         baseline_val_loss=history_baseline['val_loss'],
         baseline_train_ssim=history_baseline['train_ssim'],
         baseline_val_ssim=history_baseline['val_ssim'])

# Create the archives
shutil.make_archive('MA_TAP_Results', 'zip', '.', 'outputs')
shutil.make_archive('MA_TAP_Figures', 'zip', '.', 'figures')

print("‚úì Fichiers cr√©√©s:")
print("  - MA_TAP_Results.zip (poids des mod√®les + historiques)")
print("  - MA_TAP_Figures.zip (toutes les figures)")
print("\nüì• T√©l√©chargez ces fichiers depuis le panneau de gauche!")

# On Google Colab, trigger the browser download automatically. Only the
# expected failures are handled: a missing `google.colab` module (not running
# on Colab) and a failing download, which is reported instead of being
# silently swallowed. KeyboardInterrupt is no longer caught.
try:
    from google.colab import files
except ImportError:
    print("(T√©l√©chargement manuel depuis le panneau de fichiers)")
else:
    try:
        for archive in ('MA_TAP_Results.zip', 'MA_TAP_Figures.zip'):
            files.download(archive)
    except Exception as exc:  # best effort: the archives already exist on disk
        print(f"Automatic download failed: {exc!r}")
        print("(T√©l√©chargement manuel depuis le panneau de fichiers)")

---

## ✅ Entraînement Terminé!

### Fichiers générés:
- `outputs/matap_best.weights.h5` - Meilleurs poids MA-TAP
- `outputs/baseline_best.weights.h5` - Meilleurs poids Baseline
- `outputs/training_history.npz` - Historiques d'entraînement
- `figures/training_curves.png` - Courbes d'apprentissage
- `figures/ssim_comparison.png` - Comparaison SSIM temporel
- `figures/reconstruction_comparison.png` - Visualisation reconstructions

### Pour utiliser ces r√©sultats dans ton projet local:
1. Télécharge les fichiers `.zip`
2. Copie les `.weights.h5` dans `Part3/experiments/outputs/`
3. Copie les `.png` dans `Part3/experiments/outputs/`

---