# TP — Transfer Learning (ResNet50) sur ShipsNet (Kaggle)

Objectifs :
1. Charger le dataset **Ships in Satellite Imagery (ShipsNet)** depuis Kaggle (via `kagglehub`).
2. Préparer les données (décodage JSON, split train/val/test, prétraitement ResNet50).
3. Construire un modèle **ResNet50** pré-entraîné (ImageNet) + tête de classification.
4. Entraîner en 2 phases : **(i) head-only** puis **(ii) fine-tuning**.
5. Évaluer : accuracy, matrice de confusion, rapport de classification.

> Remarque : ce notebook est conçu pour être exécuté d'un seul tenant (« Run all »).  
> Si vous travaillez dans un environnement sans accès réseau, téléchargez le dataset localement et faites pointer `DATA_DIR` vers le dossier obtenu.


## 0. Imports et reproductibilité

In [None]:
import os
import json
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers as L

from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report, ConfusionMatrixDisplay

SEED = 42

# One seed for every RNG. keras.utils.set_random_seed seeds Python's `random`,
# NumPy and TensorFlow in a single call; NumPy is then reseeded explicitly
# (redundant, but harmless).
for _seed_rng in (tf.keras.utils.set_random_seed, np.random.seed):
    _seed_rng(SEED)
del _seed_rng

print("TF version:", tf.__version__)


## 1. Téléchargement et localisation du dataset

In [None]:
# Option A (recommended): download the dataset through kagglehub.
# Requires internet access and possibly Kaggle credentials, depending on the
# environment.
# Option B: set USE_KAGGLEHUB = False and point DATA_DIR at a local copy below.
# A DATA_DIR set here is also used as a fallback if the kagglehub download fails.

USE_KAGGLEHUB = True
DATA_DIR = None
# DATA_DIR = Path("/path/to/shipsnet")  # Option B: local path (adapt)

if USE_KAGGLEHUB:
    try:
        import kagglehub
        DATA_DIR = Path(kagglehub.dataset_download("rhammell/ships-in-satellite-imagery"))
        print("Dataset téléchargé dans :", DATA_DIR)
    except Exception as e:
        # Deliberate best-effort boundary. Failure can come from the import
        # (package missing) or from the download itself (e.g. network or
        # credentials). Report it and fall back to a manually set DATA_DIR.
        print("Téléchargement via kagglehub a échoué:", repr(e))
        USE_KAGGLEHUB = False

# Only fail when no path is available at all. The previous version raised
# unconditionally whenever USE_KAGGLEHUB was False, so Option B could never run.
if DATA_DIR is None:
    raise ValueError("Définissez DATA_DIR manuellement si USE_KAGGLEHUB=False.")

# Accept a plain string as well as a Path. Use an explicit raise rather than
# `assert`, which is stripped when Python runs with -O.
DATA_DIR = Path(DATA_DIR)
if not DATA_DIR.exists():
    raise FileNotFoundError("DATA_DIR invalide. Vérifiez téléchargement / chemin.")


## 2. Chargement de ShipsNet (JSON)

In [None]:
def find_shipsnet_json(root: Path) -> Path:
    """Return the most likely ShipsNet JSON file under *root* (searched recursively).

    JSON files whose lower-cased name contains both "ship" and "net" are
    preferred, and the largest of them is returned. Ties go to the first file
    found. If no file name matches, the largest ``.json`` file anywhere under
    *root* is returned instead.

    Raises:
        FileNotFoundError: if *root* contains no ``.json`` file at all.
    """
    def size_of(path: Path) -> int:
        return path.stat().st_size

    def looks_like_shipsnet(path: Path) -> bool:
        # "shipsnet" itself satisfies this test, so it covers every accepted spelling.
        lowered = path.name.lower()
        return "ship" in lowered and "net" in lowered

    named = [path for path in root.rglob("*.json") if looks_like_shipsnet(path)]
    if named:
        return max(named, key=size_of)

    every_json = list(root.rglob("*.json"))
    if not every_json:
        raise FileNotFoundError("Aucun fichier .json trouvé dans DATA_DIR.")
    return max(every_json, key=size_of)

# Locate the dataset JSON under DATA_DIR and parse it into `data`.
json_path = find_shipsnet_json(DATA_DIR)
print("JSON trouvé:", json_path)

data = json.loads(json_path.read_text(encoding="utf-8"))

def get_key(d, keys):
    """Return the first entry of *keys* that is present in *d*, or ``None``.

    *keys* is scanned in order, so earlier names take priority.
    """
    return next((candidate for candidate in keys if candidate in d), None)

# Resolve which JSON keys hold the images and the labels. The first name of
# each list is tried first; the others are fallbacks for differently keyed files.
x_key, y_key = (
    get_key(data, ["data", "X", "x", "images"]),
    get_key(data, ["labels", "y", "Y", "target", "targets"]),
)

if None in (x_key, y_key):
    raise KeyError(f"Impossible d'identifier les clés X/y. Clés disponibles: {list(data.keys())}")

X_raw, y_raw = np.array(data[x_key]), np.array(data[y_key])

print("X_raw shape:", X_raw.shape, "dtype:", X_raw.dtype)
print("y_raw shape:", y_raw.shape, "unique:", np.unique(y_raw, return_counts=True))

# Rebuild (N, H, W, C) images from the raw array.
# ShipsNet stores each 80x80 RGB image as a flat list of 19200 integers in
# channel-planar order:
#   - the first 6400 values are the red plane,
#   - the next 6400 the green plane,
#   - the last 6400 the blue plane,
# each plane in row-major order. A direct reshape to (N, H, W, C) would mix
# values from the three planes and scramble every image. Flattened rows are
# therefore reshaped to (N, C, H, W) first, and the channel axis is then moved last.
if X_raw.ndim == 2:
    n, d = X_raw.shape
    if d == 80 * 80 * 3:
        H, W, C = 80, 80, 3
    else:
        # Fallback for other square RGB sizes, assuming the same planar layout.
        if d % 3 != 0:
            raise ValueError(f"Dimension flatten inattendue: {d} (pas divisible par 3).")
        hw = d // 3
        s = int(round(np.sqrt(hw)))
        if s * s != hw:
            raise ValueError(f"Impossible d'inférer H=W depuis hw={hw}.")
        H, W, C = s, s, 3
    X = X_raw.reshape(n, C, H, W).transpose(0, 2, 3, 1)
elif X_raw.ndim == 4:
    # Already image-shaped. NOTE(review): assumed channels-last (N, H, W, C);
    # confirm for non-ShipsNet sources.
    X = X_raw
    H, W, C = X.shape[1], X.shape[2], X.shape[3]
else:
    raise ValueError(f"Format X_raw inattendu: ndim={X_raw.ndim}")

# Convert to float32 for TensorFlow. ascontiguousarray also materialises the
# transposed view in C order, so later indexing and tf.data see a dense buffer.
X = np.ascontiguousarray(X, dtype="float32")
y = y_raw.astype("int64")

print("X shape:", X.shape, "range:", (X.min(), X.max()))
print("y shape:", y.shape, "classes:", np.unique(y))

# Visual sanity check: the first six images with their 0/1 label.
plt.figure(figsize=(8, 3))
for i in range(6):
    # Images whose max exceeds 1.5 are treated as 0..255 values and cast to
    # uint8, so imshow renders them as 8-bit RGB instead of clipping to [0, 1].
    img = X[i].astype("uint8") if X[i].max() > 1.5 else X[i]
    plt.subplot(1, 6, i + 1)
    plt.imshow(img)
    plt.axis("off")
    plt.title(int(y[i]))
plt.suptitle("Exemples (label 0/1)")
plt.show()


## 3. Split train/val/test et prétraitement ResNet50

In [None]:
NUM_CLASSES = 2

# Stratified 70 / 15 / 15 split: hold out 30% first, then cut that hold-out in
# half to get validation and test sets with the same class balance.
X_train, X_tmp, y_train, y_tmp = train_test_split(
    X, y, stratify=y, test_size=0.30, random_state=SEED
)
X_val, X_test, y_val, y_test = train_test_split(
    X_tmp, y_tmp, stratify=y_tmp, test_size=0.50, random_state=SEED
)

print("Train:", X_train.shape, y_train.shape)
print("Val  :", X_val.shape, y_val.shape)
print("Test :", X_test.shape, y_test.shape)

# One-hot targets to match the softmax output and categorical_crossentropy loss.
y_train_oh, y_val_oh, y_test_oh = (
    keras.utils.to_categorical(labels, NUM_CLASSES)
    for labels in (y_train, y_val, y_test)
)

# Input-pipeline settings. Images are resized to IMG_SIZE and normalised with
# the ResNet50 preprocessing function inside make_ds.
BATCH_SIZE = 32
IMG_SIZE = (224, 224)

preprocess = keras.applications.resnet50.preprocess_input

def make_ds(X_arr, y_arr_oh, training: bool):
    """Build a batched tf.data pipeline that feeds images to ResNet50.

    Each image is resized to IMG_SIZE and passed through ``preprocess``
    (ResNet50 ``preprocess_input``, which expects RGB values on a 0..255
    scale). The examples are then grouped into batches of BATCH_SIZE and
    prefetched.

    Args:
        X_arr: image array of shape (N, H, W, C).
        y_arr_oh: one-hot labels aligned with ``X_arr``.
        training: if True, the examples are reshuffled at every epoch.

    Returns:
        A ``tf.data.Dataset`` yielding ``(images, labels)`` batches.
    """
    def _prepare(image, label):
        # Resize and preprocess in a single map stage (one pipeline hop, not two).
        return preprocess(tf.image.resize(image, IMG_SIZE)), label

    ds = tf.data.Dataset.from_tensor_slices((X_arr, y_arr_oh))
    if training:
        # A uniform shuffle needs a buffer at least as large as the dataset. The
        # former fixed 2048 only partially mixed a larger training split.
        # max(1, ...) keeps the buffer size valid for an empty array.
        ds = ds.shuffle(max(1, len(X_arr)), seed=SEED, reshuffle_each_iteration=True)
    ds = ds.map(_prepare, num_parallel_calls=tf.data.AUTOTUNE)
    return ds.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

# Only the training split is shuffled; val/test keep their order so that
# predictions stay aligned with y_val / y_test.
train_ds, val_ds, test_ds = (
    make_ds(X_train, y_train_oh, training=True),
    make_ds(X_val, y_val_oh, training=False),
    make_ds(X_test, y_test_oh, training=False),
)

# Pull a single batch to sanity-check shapes and dtypes.
xb, yb = next(iter(train_ds))
print("Batch X:", xb.shape, xb.dtype, "Batch y:", yb.shape, yb.dtype)


## 4. Modèle : ResNet50 + tête Dense

In [None]:
def build_resnet50_transfer(num_classes=2, input_shape=(224, 224, 3), dropout=0.3):
    """Build an ImageNet-pretrained ResNet50 backbone topped with a dense head.

    The head is: global average pooling, Dense(256, relu), Dropout(dropout),
    then Dense(num_classes, softmax). The backbone is returned frozen for the
    head-only phase, and is also returned on its own so it can be partially
    unfrozen later.

    Returns:
        ``(model, base)``: the full Keras model and the ResNet50 backbone.
    """
    inputs = L.Input(shape=input_shape)
    backbone = keras.applications.ResNet50(
        include_top=False, weights="imagenet", input_tensor=inputs
    )
    # Phase 1: the whole backbone is frozen; only the head is trained.
    backbone.trainable = False

    features = L.GlobalAveragePooling2D()(backbone.output)
    hidden = L.Dense(256, activation="relu")(features)
    hidden = L.Dropout(dropout)(hidden)
    probs = L.Dense(num_classes, activation="softmax")(hidden)

    return keras.Model(inputs, probs, name="resnet50_shipsnet"), backbone

# Derive the input shape from IMG_SIZE so the network always matches the
# resolution produced by make_ds (RGB images, hence 3 channels), instead of
# relying on the function's hard-coded 224x224 default.
model_ship, base_model = build_resnet50_transfer(
    num_classes=NUM_CLASSES, input_shape=(*IMG_SIZE, 3)
)
model_ship.summary()


## 5. Entraînement — Phase 1 (head-only)

In [None]:
# Phase 1: only the new head is trainable, so a relatively large LR is used.
model_ship.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    loss="categorical_crossentropy",
    metrics=["accuracy"],
)

# This list is reused in phase 2. Both callbacks reset their internal state at
# the start of each fit().
callbacks = [
    keras.callbacks.EarlyStopping(
        monitor="val_loss", patience=3, restore_best_weights=True
    ),
    keras.callbacks.ReduceLROnPlateau(
        monitor="val_loss", factor=0.5, patience=2, min_lr=1e-6
    ),
]

history_1 = model_ship.fit(
    x=train_ds,
    validation_data=val_ds,
    epochs=15,
    callbacks=callbacks,
    verbose=1,
)


### Courbes (Phase 1)

In [None]:
def plot_history(hist, title_prefix=""):
    """Plot train vs. validation curves from a Keras ``History`` object.

    Two figures are drawn and shown, loss first and accuracy second. Each
    title starts with *title_prefix*.
    """
    curves = hist.history
    for metric, label in (("loss", "Loss"), ("accuracy", "Accuracy")):
        plt.figure()
        plt.plot(curves[metric])
        plt.plot(curves[f"val_{metric}"])
        plt.legend(["train", "val"])
        plt.title(f"{title_prefix}{label}")
        plt.xlabel("epoch")
        plt.show()

plot_history(history_1, title_prefix="Phase 1 — ")


## 6. Fine-tuning — Phase 2 (dégel partiel)

In [None]:
# Good practice for fine-tuning:
# - unfreeze only the top of the backbone,
# - keep BatchNormalization layers frozen (with trainable=False Keras runs
#   them in inference mode, so small fine-tuning batches do not disturb
#   their moving statistics),
# - recompile with a much lower learning rate.

# Cut point: the index of the first layer whose name contains "conv5". If no
# layer matches, fall back to keeping the bottom two thirds frozen.
fine_tune_from = next(
    (i for i, layer in enumerate(base_model.layers) if "conv5" in layer.name),
    None,
)
if fine_tune_from is None:
    fine_tune_from = int(len(base_model.layers) * 2 / 3)

base_model.trainable = True
for i, layer in enumerate(base_model.layers):
    # Trainable only above the cut, and never for BatchNormalization layers.
    layer.trainable = i >= fine_tune_from and not isinstance(layer, L.BatchNormalization)

print("Fine-tune from layer index:", fine_tune_from, "/", len(base_model.layers))
print("Base trainable layers:", sum(l.trainable for l in base_model.layers), "/", len(base_model.layers))

# Recompilation is required for the new trainable flags to take effect.
model_ship.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-5),
    loss="categorical_crossentropy",
    metrics=["accuracy"],
)

history_2 = model_ship.fit(
    x=train_ds,
    validation_data=val_ds,
    epochs=8,
    callbacks=callbacks,
    verbose=1,
)


### Courbes (Phase 2)

In [None]:
plot_history(history_2, "Phase 2 — ")  # loss and accuracy curves of the fine-tuning run


## 7. Évaluation finale (test)

In [None]:
# Final metrics on the held-out test split.
test_loss, test_acc = model_ship.evaluate(test_ds, verbose=1)
print(f"Test loss: {test_loss:.4f} | Test accuracy: {test_acc:.4f}")

# test_ds is built without shuffling, so prediction order matches y_test.
y_prob = model_ship.predict(test_ds, verbose=0)
y_pred = y_prob.argmax(axis=1)

print("\nClassification report:")
print(classification_report(y_test, y_pred, target_names=["no-ship", "ship"]))

cm = confusion_matrix(y_test, y_pred)
disp = ConfusionMatrixDisplay(cm, display_labels=["no-ship", "ship"])
disp.plot(values_format="d")
plt.title("Confusion matrix (test)")
plt.show()


## 8. Sauvegarde du modèle

In [None]:
# Save the trained model in the native Keras format under a local folder.
EXPORT_DIR = Path("export_resnet50_shipsnet")
model_path = EXPORT_DIR / "model_ship.keras"

EXPORT_DIR.mkdir(exist_ok=True)
model_ship.save(model_path)
print("Modèle sauvegardé:", model_path.resolve())
