# In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.callbacks import EarlyStopping
import matplotlib.pyplot as plt
import numpy as np
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns

# ---------------------------
# 1. Settings
# ---------------------------
# Forward slashes are accepted on Windows as well as Linux/macOS; a backslash
# literal is only treated as a path separator on Windows.
DATASET_PATH = "Imagini/TrainImages"
IMG_SIZE = (224, 224)   # size every image is resized to when loaded
BATCH_SIZE = 32
SEED = 42               # shared by both subsets so the train/val split matches
EPOCHS = 40             # upper bound; EarlyStopping may end training sooner
AUTOTUNE = tf.data.AUTOTUNE

# ---------------------------
# 2. Load the data
# ---------------------------
# Both subsets are read from the same directory with the same seed and
# validation_split so that Keras partitions the files consistently.
_split_kwargs = dict(
    validation_split=0.2,
    seed=SEED,
    image_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
)
train_ds_raw = tf.keras.preprocessing.image_dataset_from_directory(
    DATASET_PATH, subset="training", **_split_kwargs
)
val_ds_raw = tf.keras.preprocessing.image_dataset_from_directory(
    DATASET_PATH, subset="validation", **_split_kwargs
)

# Class names as discovered by the training loader; their count sizes the
# model's output layer.
class_names = train_ds_raw.class_names
num_classes = len(class_names)
print(" Clase:", class_names)

# ---------------------------
# 3. Augmentation + input pipeline
# ---------------------------
# Random transformations applied to training batches only.
data_augmentation = tf.keras.Sequential([
    layers.RandomFlip("horizontal"),
    layers.RandomRotation(0.1),
    layers.RandomZoom(0.1),
])

# num_parallel_calls lets tf.data augment several batches concurrently
# instead of one at a time; prefetch overlaps input preparation with training.
train_ds = (
    train_ds_raw
    .map(lambda x, y: (data_augmentation(x, training=True), y),
         num_parallel_calls=AUTOTUNE)
    .prefetch(AUTOTUNE)
)
# Validation batches are used as-is (no augmentation).
val_ds = val_ds_raw.prefetch(AUTOTUNE)

# ---------------------------
# 4. Class weights
# ---------------------------
# Collect every training label (this iterates the full training dataset once).
y_labels = np.concatenate([labels.numpy() for _, labels in train_ds_raw])

present_classes = np.unique(y_labels)
class_weights = compute_class_weight(
    class_weight="balanced",
    classes=present_classes,
    y=y_labels,
)
# Key each weight by its real class id. dict(enumerate(...)) was only correct
# when every class occurred in the training split; otherwise weights shifted
# onto the wrong classes. tf.keras expects keys 0..num_classes-1, so a class
# absent from the split gets a neutral 1.0 (it has no training samples, so
# the value is never applied).
class_weights_dict = {class_id: 1.0 for class_id in range(num_classes)}
class_weights_dict.update(
    (int(class_id), float(weight))
    for class_id, weight in zip(present_classes, class_weights)
)
print(" Class Weights:", class_weights_dict)

# ---------------------------
# 5. Model CNN simplu
# ---------------------------
def build_simple_cnn(input_shape, num_classes):
    """Build a small three-stage convolutional classifier.

    Args:
        input_shape: Shape of a single input image, e.g. ``(height, width, 3)``.
        num_classes: Number of output classes (width of the softmax layer).

    Returns:
        An uncompiled ``Sequential`` model that outputs per-class
        probabilities. It divides its inputs by 255 internally, so it
        presumably expects raw pixel values in [0, 255].
    """
    return models.Sequential([
        # An explicit Input layer replaces the ``input_shape=`` keyword on a
        # regular layer, which Keras 3 deprecates.
        layers.Input(shape=input_shape),
        # Scaling lives inside the model so training and inference see the
        # same preprocessing.
        layers.Rescaling(1./255),

        # Three conv/pool stages, doubling the filter count each time.
        layers.Conv2D(32, (3, 3), activation='relu'),
        layers.MaxPooling2D(),

        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.MaxPooling2D(),

        layers.Conv2D(128, (3, 3), activation='relu'),
        layers.MaxPooling2D(),

        # Classifier head; dropout regularises the dense layer.
        layers.Flatten(),
        layers.Dense(128, activation='relu'),
        layers.Dropout(0.5),
        layers.Dense(num_classes, activation='softmax'),
    ])

# Input images carry 3 colour channels on top of the configured IMG_SIZE.
simple_cnn = build_simple_cnn(IMG_SIZE + (3,), num_classes)
simple_cnn.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',  # integer labels vs. softmax output
    metrics=['accuracy'],
)

# ---------------------------
# 6. Callbacks
# ---------------------------
# Stop after 5 epochs without val_loss improvement and roll back to the
# weights of the best epoch.
early_stop = EarlyStopping(
    monitor='val_loss',
    patience=5,
    restore_best_weights=True,
)

# ---------------------------
# 7. Train the model
# ---------------------------
print(" Training CNN simplu...")
# Class weights rebalance the loss across classes; training runs for at most
# EPOCHS epochs and may stop earlier through early_stop.
training_callbacks = [early_stop]
history = simple_cnn.fit(
    train_ds,
    epochs=EPOCHS,
    validation_data=val_ds,
    callbacks=training_callbacks,
    class_weight=class_weights_dict,
)

# Save the trained model. The format is inferred from the ".keras" extension;
# the explicit save_format argument is deprecated in Keras 3.
try:
    simple_cnn.save("simple_model.keras")
    print(" Model salvat cu succes în formatul Keras.")
except TypeError as e:
    print(" Eroare la salvare completă:", e)
    print(" Salvăm separat arhitectura și greutățile (fallback).")

    # Fallback: write the architecture (JSON) and the weights separately.
    with open("simple_model_architecture.json", "w", encoding="utf-8") as json_file:
        json_file.write(simple_cnn.to_json())

    # Keras 3 rejects weight filenames that do not end in ".weights.h5";
    # older tf.keras also writes this name in HDF5 format (".h5" suffix).
    simple_cnn.save_weights("simple_model_weights.weights.h5")
    print(" Arhitectură și greutăți salvate separat.")

# ---------------------------
# 8. Plotare rezultate
# ---------------------------
def plot_history(history):
    """Plot training vs. validation accuracy and loss side by side.

    Args:
        history: The object returned by ``Model.fit``; its ``history`` dict
            must contain ``accuracy``, ``val_accuracy``, ``loss`` and
            ``val_loss`` series.
    """
    # (train key, train label, val key, val label, title, y-axis label)
    panels = (
        ('accuracy', 'Train Accuracy', 'val_accuracy', 'Val Accuracy', 'Acuratețe', 'Accuracy'),
        ('loss', 'Train Loss', 'val_loss', 'Val Loss', 'Pierdere', 'Loss'),
    )
    series = history.history

    plt.figure(figsize=(12, 5))
    for position, (train_key, train_label, val_key, val_label, title, y_label) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.plot(series[train_key], label=train_label)
        plt.plot(series[val_key], label=val_label)
        plt.title(title)
        plt.xlabel('Epoci')
        plt.ylabel(y_label)
        plt.legend()

    plt.tight_layout()
    plt.show()

plot_history(history)

# ---------------------------
# 9. Final evaluation on the validation set
# ---------------------------
print(" Evaluare finală pe setul de validare...")
y_true = []
y_pred = []

for images, labels in val_ds:
    # predict_on_batch runs one forward pass without the per-call overhead
    # and progress-bar output that predict() would add for every batch.
    preds = simple_cnn.predict_on_batch(images)
    y_true.extend(labels.numpy())
    y_pred.extend(np.argmax(preds, axis=1))

# Pin the label set to every class. Then a class missing from this validation
# split neither breaks target_names (sklearn raises on a length mismatch) nor
# shifts the confusion-matrix rows/columns away from their tick labels.
class_ids = list(range(num_classes))

print("\n Classification Report:")
print(classification_report(
    y_true, y_pred, labels=class_ids, target_names=class_names, zero_division=0
))

# ---------------------------
# 10. Confusion matrix
# ---------------------------
cm = confusion_matrix(y_true, y_pred, labels=class_ids)
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', xticklabels=class_names, yticklabels=class_names, cmap='Blues')
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title(' Matrice de Confuzie')
plt.show()
