In [None]:
import os
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
import kagglehub

print("ðŸš€ TP TRM-VISION M2 â€“ Fashion-MNIST Kaggle (small vs large + classifier)")

# =============================================================================
# 0. Download Kaggle Fashion-MNIST
# =============================================================================
# `path` is the local directory the dataset files end up in.
path = kagglehub.dataset_download("zalando-research/fashionmnist")
print("Path to dataset files:", path)

# The Kaggle dataset ships one CSV per split:
#   fashion-mnist_train.csv and fashion-mnist_test.csv
train_csv, test_csv = (
    os.path.join(path, f"fashion-mnist_{split}.csv") for split in ("train", "test")
)

# =============================================================================
# 1. TinyBlock
# =============================================================================
class TinyBlock(layers.Layer):
    """Pre-norm residual MLP block: ``u + fc2(gelu(fc1(LayerNorm(u))))``.

    The hidden layer is ``4 * d`` wide and the output projection is ``d`` wide,
    so the residual sum assumes the input's last dimension equals ``d``.
    """

    def __init__(self, d):
        super().__init__()
        self.ln = layers.LayerNormalization()
        self.fc1 = layers.Dense(4 * d, activation="gelu")  # expand to 4*d
        self.fc2 = layers.Dense(d)                          # project back to d

    def call(self, u):
        # Residual connection around normalize -> expand -> project.
        return u + self.fc2(self.fc1(self.ln(u)))

# =============================================================================
# 2. TRM-VISION simplifié
# =============================================================================
class TRM_VISION(keras.Model):
    """Simplified class-conditional TRM-style image generator.

    A class id is embedded and broadcast over every position of an
    ``img_size x img_size`` canvas. Two latent streams, ``y`` and ``z``, start
    from learned (zero-initialised) tensors and are refined for ``n_rec``
    recursion steps. Both updates of every step reuse the same pair of
    TinyBlocks. ``y`` is finally decoded to one value per position and squashed
    to [-1, 1] with tanh.

    Args:
        img_size: side length of the generated square image.
        d: latent width of each canvas position.
        n_rec: number of recursion steps.
        name_suffix: appended to the model name and to the initial-state weight names.
    """

    def __init__(self, img_size=16, d=32, n_rec=2, name_suffix="small"):
        super().__init__(name=f"TRM_VISION_{name_suffix}")
        self.img_size = img_size
        self.d = d
        self.n_rec = n_rec

        # One embedding per class; Fashion-MNIST has 10 classes.
        self.cond_emb = layers.Embedding(10, d)
        # Learned initial states of the two latent streams: one d-vector per pixel.
        self.y0 = self.add_weight(
            shape=(1, img_size*img_size, d),
            initializer="zeros",
            trainable=True,
            name=f"y0_{name_suffix}"
        )
        self.z0 = self.add_weight(
            shape=(1, img_size*img_size, d),
            initializer="zeros",
            trainable=True,
            name=f"z0_{name_suffix}"
        )
        self.block1 = TinyBlock(d)
        self.block2 = TinyBlock(d)
        self.to_pixels = layers.Dense(1)

    def _decode(self, y, batch_size):
        """Map latent ``y`` ([B, L, d]) to an image [B, img_size, img_size, 1] in [-1, 1]."""
        pixels = self.to_pixels(y)  # [B, L, 1]
        return tf.tanh(tf.reshape(pixels, [batch_size, self.img_size, self.img_size, 1]))

    def call(self, class_tokens, target_img=None, return_intermediate=False):
        """Generate one image per class id.

        Args:
            class_tokens: integer class ids, expected shape [B] (one id per sample).
            target_img: optional reference images holding
                ``B * img_size * img_size`` values, e.g. [B, H, W] or [B, H, W, 1].
                They are reshaped to the output shape before computing the MSE.
            return_intermediate: if True, also return the decoded image after
                every recursion step. The last entry is the returned image.

        Returns:
            ``img`` when neither ``target_img`` nor ``return_intermediate`` is given;
            ``(img, loss)`` when only ``target_img`` is given;
            ``(img, loss, inter_imgs)`` whenever ``return_intermediate`` is True
            (``loss`` is None if no target was given).
        """
        B = tf.shape(class_tokens)[0]
        L = self.img_size * self.img_size

        # Embed the class and broadcast it over the whole canvas.
        c = self.cond_emb(class_tokens)              # [B, d]
        c = tf.tile(tf.expand_dims(c, 1), [1, L, 1]) # [B, L, d]

        # Initial states, replicated for every sample of the batch.
        y = tf.tile(self.y0, [B, 1, 1])
        z = tf.tile(self.z0, [B, 1, 1])

        inter_imgs = []

        # n_rec recursion steps, all differentiable: z is updated from c + y + z,
        # then y from y + z.
        for step in range(self.n_rec):
            z = self.block2(self.block1(c + y + z))
            y = self.block2(self.block1(y + z))
            # Snapshot every step except the last; the last one is `img` below.
            if return_intermediate and step < self.n_rec - 1:
                inter_imgs.append(self._decode(y, B))

        img = self._decode(y, B)
        if return_intermediate:
            inter_imgs.append(img)

        loss = None
        if target_img is not None:
            # Align the target with img's [B, H, W, 1] shape. Otherwise a channel-less
            # [B, H, W] target would broadcast into a wrong-shaped difference.
            target = tf.reshape(tf.cast(target_img, img.dtype), tf.shape(img))
            loss = tf.reduce_mean(tf.square(target - img))

        if return_intermediate:
            return img, loss, inter_imgs
        if loss is None:
            return img
        return img, loss

# =============================================================================
# 3. Dataset Kaggle Fashion-MNIST → 16x16
# =============================================================================

def load_fashion_mnist_csv(csv_path):
    """Read a Fashion-MNIST CSV into images and labels.

    The CSV must have a ``label`` column. Every other column is treated as one
    pixel, and each row must hold 28 * 28 of them.

    Returns:
        ``(images, labels)``: float32 array [N, 28, 28] of raw pixel values and
        int64 array [N] of class ids.
    """
    frame = pd.read_csv(csv_path)
    # Remove the label column so only pixel columns remain in `frame`.
    label_array = frame.pop('label').to_numpy().astype(np.int64)
    flat_pixels = frame.to_numpy().astype(np.float32)
    return flat_pixels.reshape(-1, 28, 28), label_array

x_train_full, y_train = load_fashion_mnist_csv(train_csv)
x_test_full,  y_test  = load_fashion_mnist_csv(test_csv)

# Normalize pixel intensities from [0, 255] to [-1, 1], the same range as the
# generator's tanh output.
x_train_full = x_train_full / 127.5 - 1.0
x_test_full  = x_test_full  / 127.5 - 1.0

# Resize 28x28 -> 16x16. tf.image.resize needs a channel axis, so one is added
# and then removed with an explicit [..., 0]. np.squeeze would also drop the
# batch axis of a split that happened to contain a single image.
x_train = tf.image.resize(tf.expand_dims(x_train_full, -1), [16, 16]).numpy()[..., 0]
x_test  = tf.image.resize(tf.expand_dims(x_test_full,  -1), [16, 16]).numpy()[..., 0]

print(f"âœ… Fashion-MNIST Kaggle: x_train={x_train.shape}, y_train={y_train.shape}")

# Subsample to speed things up. These are element counts (not indices), and the
# validation subset is taken from the *test* split.
train_idx = 8000
val_idx   = 2000

x_train_small = x_train[:train_idx]
y_train_small = y_train[:train_idx]
x_val_small   = x_test[:val_idx]
y_val_small   = y_test[:val_idx]

# =============================================================================
# 4. Modèles TRM small et large
# =============================================================================
# Two variants on the same 16x16 canvas: "small" (d=32, 2 recursion steps) and
# "large" (d=64, 4 recursion steps). Both are built and compiled before either trains.
trm_small = TRM_VISION(img_size=16, d=32, n_rec=2, name_suffix="small")
trm_large = TRM_VISION(img_size=16, d=64, n_rec=4, name_suffix="large")

trm_small.compile(optimizer='adam', loss='mse')
trm_large.compile(optimizer='adam', loss='mse')


def _fit_trm(trm_model, size_tag):
    """Train one TRM model: inputs are class ids, targets are the 16x16 images.

    NOTE(review): targets are [N, 16, 16] while the model outputs [N, 16, 16, 1].
    This relies on Keras' loss wrapper reconciling the trailing size-1 axis;
    confirm on the Keras version in use.
    """
    print(f"\nðŸ”§ EntraÃ®nement TRM {size_tag} (Fashion-MNIST)...")
    return trm_model.fit(
        y_train_small,
        x_train_small,
        epochs=3,
        batch_size=256,
        verbose=1,
        validation_data=(y_val_small, x_val_small),
    )


hist_small = _fit_trm(trm_small, "small")
hist_large = _fit_trm(trm_large, "large")

# =============================================================================
# 5. Classifier CNN Fashion-MNIST 16x16
# =============================================================================
print("\nðŸ”§ EntraÃ®nement CNN classifier Fashion-MNIST (16x16)...")

# Reference classifier used to judge the generated images. With Conv2D's default
# "valid" padding the spatial size goes 16 -> 14 (conv) -> 7 (pool) -> 5 (conv) -> 2 (pool).
cnn_layers = [
    layers.Input(shape=(16, 16, 1)),
    layers.Conv2D(32, 3, activation="relu"),
    layers.MaxPool2D(2),
    layers.Conv2D(64, 3, activation="relu"),
    layers.MaxPool2D(2),
    layers.Flatten(),
    layers.Dense(128, activation="relu"),
    layers.Dense(10, activation="softmax"),
]
cnn = keras.Sequential(cnn_layers)

# Integer labels + softmax outputs -> sparse categorical cross-entropy.
cnn.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)

# The images are stored as [N, 16, 16]; append the channel axis the model expects.
cnn.fit(
    np.expand_dims(x_train_small, -1),
    y_train_small,
    epochs=5,
    batch_size=256,
    verbose=1,
    validation_data=(np.expand_dims(x_val_small, -1), y_val_small),
)

# =============================================================================
# 6. Évaluation: taux de reconnaissance des images générées
# =============================================================================
def eval_trm_with_cnn(trm_model, name, n_samples=500, *, classifier=None, seed=None):
    """Measure how often a classifier recognises the class a TRM model was asked to draw.

    Draws ``n_samples`` random class ids in [0, 10) and generates one image per id
    with ``trm_model``. It then classifies the images, prints the fraction whose
    predicted class equals the requested one, and returns it.

    Args:
        trm_model: generator called as ``trm_model(class_ids)``. The result must
            support ``.numpy()`` and be accepted by ``classifier.predict``.
        name: label used in the printed report.
        n_samples: number of images to generate (must be positive).
        classifier: model with a Keras-style ``predict`` returning per-class scores.
            Defaults to the module-level ``cnn``, looked up at call time.
        seed: if given, labels come from ``np.random.default_rng(seed)`` so the
            evaluation is reproducible. If None, the global NumPy RNG is used, as before.

    Returns:
        ``(acc, gen_imgs, labels)``: accuracy in [0, 1], the generated images and
        the requested class ids.

    Raises:
        ValueError: if ``n_samples`` is not positive.
    """
    if n_samples <= 0:
        raise ValueError(f"n_samples must be positive, got {n_samples}")
    if classifier is None:
        classifier = cnn
    if seed is None:
        labels = np.random.randint(0, 10, size=(n_samples,))
    else:
        labels = np.random.default_rng(seed).integers(0, 10, size=(n_samples,))
    gen_imgs = trm_model(tf.constant(labels)).numpy()  # [N,16,16,1] for the 16x16 TRM models
    preds = classifier.predict(gen_imgs, verbose=0)
    pred_labels = preds.argmax(axis=1)
    acc = (pred_labels == labels).mean()
    print(f"ðŸ“Š Taux de reconnaissance CNN sur images Fashion-MNIST gÃ©nÃ©rÃ©es ({name}): {acc*100:.2f}%")
    return acc, gen_imgs, labels

# Score both generators with the CNN. Order matters for reproducibility: each
# call draws its labels from the shared global NumPy RNG.
acc_small, gen_small_imgs, gen_small_labels = eval_trm_with_cnn(
    trm_small, "TRM small", n_samples=500
)
acc_large, gen_large_imgs, gen_large_labels = eval_trm_with_cnn(
    trm_large, "TRM large", n_samples=500
)

# =============================================================================
# 7. Grilles de génération 0-9 pour les deux modèles (vêtements)
# =============================================================================
# Fashion-MNIST class names, indexed by label id.
class_names = [
    "T-shirt/top",  # 0
    "Trouser",      # 1
    "Pullover",     # 2
    "Dress",        # 3
    "Coat",         # 4
    "Sandal",       # 5
    "Shirt",        # 6
    "Sneaker",      # 7
    "Bag",          # 8
    "Ankle boot",   # 9
]

def plot_grid(trm_model, title):
    """Show one generated image per class id 0-9 in a 2x5 grid.

    All ten class ids are generated in a single batched forward pass instead of
    ten separate model calls. Panels are titled with ``class_names``.

    Args:
        trm_model: generator called as ``trm_model(class_ids)``, returning images
            shaped [N, H, W, 1] with values in [-1, 1].
        title: figure title.
    """
    gen_imgs = trm_model(tf.constant(list(range(10)))).numpy()  # [10, H, W, 1]
    fig, axes = plt.subplots(2, 5, figsize=(15, 6))
    # axes.flat is row-major, so panel i sits at row i // 5, column i % 5.
    for i, ax in enumerate(axes.flat):
        ax.imshow(gen_imgs[i, :, :, 0], cmap='gray', vmin=-1, vmax=1)
        ax.set_title(f"{i}: {class_names[i]}")
        ax.axis('off')
    plt.suptitle(title, fontsize=16)
    plt.tight_layout()
    plt.show()

# One 0-9 grid per model.
plot_grid(trm_model=trm_small, title="TRM-VISION SMALL: GÃ©nÃ©ration Fashion-MNIST 0-9")
plot_grid(trm_model=trm_large, title="TRM-VISION LARGE: GÃ©nÃ©ration Fashion-MNIST 0-9")

# =============================================================================
# 8. Target vs généré pour une classe (ex: classe 5 = Sandal)
# =============================================================================
def plot_target_vs_gen(model, x_test_arr, y_test_arr, digit=5, title_prefix="TRM", *, n_show=4):
    """Compare real images of one class (top row) with generated ones (bottom row).

    Args:
        model: generator called as ``model(class_ids)``, returning [N, H, W, 1].
        x_test_arr: real images; ``x_test_arr[k]`` must be a 2-D image.
        y_test_arr: class ids aligned with ``x_test_arr``.
        digit: class id to display (a Fashion-MNIST class, e.g. 5 = Sandal).
        title_prefix: model label used in the titles.
        n_show: maximum number of columns. Fewer are shown when the class has
            fewer examples in ``y_test_arr``.

    Raises:
        ValueError: if ``y_test_arr`` contains no example of ``digit``.
    """
    idx = np.flatnonzero(np.asarray(y_test_arr) == digit)[:n_show]
    if idx.size == 0:
        raise ValueError(f"no example of class {digit} in y_test_arr")
    n_cols = idx.size
    target = x_test_arr[idx]
    # NOTE: TRM_VISION as written has no noise input, so the generated columns
    # for a single class are expected to be identical.
    gen = model(tf.constant([digit] * n_cols)).numpy()
    # squeeze=False keeps `axes` 2-D even when only one column is shown.
    fig, axes = plt.subplots(2, n_cols, figsize=(3 * n_cols, 6), squeeze=False)
    for i in range(n_cols):
        axes[0,i].imshow(target[i], cmap='gray', vmin=-1, vmax=1)
        axes[0,i].set_title(f"Vrai {digit}: {class_names[digit]}")
        axes[0,i].axis('off')
        axes[1,i].imshow(gen[i, :, :, 0], cmap='gray', vmin=-1, vmax=1)
        axes[1,i].set_title(f"{title_prefix} GÃ©nÃ©rÃ©")
        axes[1,i].axis('off')
    plt.suptitle(f"Classe {digit} ({class_names[digit]}): RÃ©el vs {title_prefix}", fontsize=16)
    plt.tight_layout()
    plt.show()

# Real vs generated comparison for class 5 (Sandal), one figure per model.
plot_target_vs_gen(model=trm_small, x_test_arr=x_test, y_test_arr=y_test, digit=5, title_prefix="TRM small")
plot_target_vs_gen(model=trm_large, x_test_arr=x_test, y_test_arr=y_test, digit=5, title_prefix="TRM large")

# =============================================================================
# 9. Visualisation trajectoire interne (modèle large, une classe)
# =============================================================================
digit_demo = 3  # class id 3 ("Dress" in class_names)
_, _, inter_imgs = trm_large(tf.constant([digit_demo]), return_intermediate=True)

# One panel per returned image. squeeze=False always yields a 2-D array of axes,
# so a single panel needs no special case.
n_panels = len(inter_imgs)
fig, axes = plt.subplots(1, n_panels, figsize=(4 * n_panels, 4), squeeze=False)
axes = axes[0]

for i, img_t in enumerate(inter_imgs):
    panel = axes[i]
    panel.imshow(img_t[0, :, :, 0].numpy(), cmap='gray', vmin=-1, vmax=1)
    panel.set_title(f"Step {i + 1}")
    panel.axis('off')
plt.suptitle(f"Ã‰volution interne TRM large â€“ classe {digit_demo} ({class_names[digit_demo]})", fontsize=16)
plt.tight_layout()
plt.show()

# =============================================================================
# 10. Récapitulatif console
# =============================================================================
# evaluate() returns [loss, accuracy] because the CNN was compiled with one metric.
cnn_eval = cnn.evaluate(np.expand_dims(x_val_small, -1), y_val_small, verbose=0)
val_acc_cnn = cnn_eval[1] * 100

summary_lines = [
    "\n" + "=" * 70,
    "ðŸŽ‰ RÃ©sumÃ© TP TRM-VISION â€“ Fashion-MNIST Kaggle",
    f"   â†’ TRM small: d=32, n_rec=2, loss fin:  {hist_small.history['loss'][-1]:.4f}",
    f"   â†’ TRM large: d=64, n_rec=4, loss fin:  {hist_large.history['loss'][-1]:.4f}",
    f"   â†’ CNN val accuracy: {val_acc_cnn:.2f}%",
    f"   â†’ Reconnaissance CNN sur images Fashion-MNIST TRM small: {acc_small*100:.2f}%",
    f"   â†’ Reconnaissance CNN sur images Fashion-MNIST TRM large: {acc_large*100:.2f}%",
    "   â†’ Figures: grilles 0-9 (vÃªtements), vrai vs gÃ©nÃ©rÃ©, trajectoire interne",
    "=" * 70,
]
for summary_line in summary_lines:
    print(summary_line)


  from .autonotebook import tqdm as notebook_tqdm


🚀 TP TRM-VISION M2 – Fashion-MNIST Kaggle (small vs large + classifier)
Downloading from https://www.kaggle.com/api/v1/datasets/download/zalando-research/fashionmnist?dataset_version_number=4...


100%|██████████| 68.8M/68.8M [00:12<00:00, 5.86MB/s]

Extracting files...





Path to dataset files: C:\Users\User\.cache\kagglehub\datasets\zalando-research\fashionmnist\versions\4
✅ Fashion-MNIST Kaggle: x_train=(60000, 16, 16), y_train=(60000,)

🔧 Entraînement TRM small (Fashion-MNIST)...
Epoch 1/3
Epoch 2/3
Epoch 3/3

🔧 Entraînement TRM large (Fashion-MNIST)...
Epoch 1/3
Epoch 2/3
Epoch 3/3