In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

from sklearn.metrics import classification_report, confusion_matrix
from sklearn.utils.multiclass import unique_labels

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import (Conv2D, MaxPooling2D, Flatten,
                                     Dense, Dropout, BatchNormalization)
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import EarlyStopping

In [None]:
# === Dataset location ===
DATASET_DIR = Path('./malimg_dataset/malimg_paper_dataset_imgs')
# Labels with fewer images than this are discarded before training.
MIN_SAMPLES_PER_CLASS = 100

# === Data preprocessing ===
# Collect every PNG below DATASET_DIR. The result is sorted so the row order of
# the DataFrame does not depend on the order in which the filesystem happens to
# list directory entries.
filepaths = sorted(DATASET_DIR.glob('**/*.png'))
if not filepaths:
    # Fail here with a clear message instead of letting an empty DataFrame
    # surface later as an obscure error inside the Keras generators.
    raise FileNotFoundError(
        f"No PNG images found under '{DATASET_DIR}'. "
        "Check that the dataset has been downloaded and extracted to that path."
    )

# The name of the directory that contains an image is used as its label.
labels = [p.parent.name for p in filepaths]

df = pd.DataFrame({
    'filepath': [str(p) for p in filepaths],  # plain strings, not Path objects
    'label': labels
})

# Keep only the labels that have at least MIN_SAMPLES_PER_CLASS images
df = df.groupby('label').filter(lambda x: len(x) >= MIN_SAMPLES_PER_CLASS)

# Show the remaining classes and how many images each one has
print(f"Total de imágenes: {len(df)}")
print(f"Familias únicas: {df['label'].nunique()}")
print(df['label'].value_counts())


In [None]:
# === Show a few example images ===
# 3 x 5 grid: one randomly drawn image per cell, titled with its label.
fig, axes = plt.subplots(3, 5, figsize=(20, 12))
sample_rows = df.sample(15)
for ax, image_path, family in zip(axes.ravel(),
                                  sample_rows['filepath'],
                                  sample_rows['label']):
    ax.imshow(plt.imread(image_path), cmap='gray')
    ax.set_title(family, fontsize=9)
    ax.axis('off')  # hide ticks and frame; only the image matters here
plt.tight_layout()
plt.show()


In [None]:
# === Data preparation ===
IMG_SIZE = (128, 128)
BATCH_SIZE = 32
VALIDATION_FRACTION = 0.3
SPLIT_SEED = 42

# Stratified hold-out split.
# ImageDataGenerator(validation_split=...) combined with flow_from_dataframe
# splits by ROW POSITION (the first fraction of rows becomes the validation
# subset). `df` is ordered by class directory, so that positional split puts
# different classes in the training and validation subsets. Sampling the same
# fraction from every label keeps all classes in both subsets.
valid_df = df.groupby('label').sample(frac=VALIDATION_FRACTION,
                                      random_state=SPLIT_SEED)
train_df = df.drop(valid_df.index)

# Pass the class list explicitly so both generators are guaranteed to share the
# same label -> index mapping (sorted order matches what Keras infers itself).
class_names = sorted(df['label'].unique())

# Random augmentation is applied to the training images only.
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=10,
    width_shift_range=0.1,
    height_shift_range=0.1,
    horizontal_flip=True
)

# Validation images are only rescaled: evaluating on randomly transformed
# images would make the validation metrics noisy and non-repeatable.
valid_datagen = ImageDataGenerator(rescale=1./255)

train_generator = train_datagen.flow_from_dataframe(
    train_df,
    x_col='filepath',
    y_col='label',
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    classes=class_names,
    shuffle=True,
    seed=SPLIT_SEED
)

# shuffle=False keeps the batch order aligned with `valid_generator.classes`,
# which the metrics cell relies on when comparing predictions to labels.
valid_generator = valid_datagen.flow_from_dataframe(
    valid_df,
    x_col='filepath',
    y_col='label',
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    classes=class_names,
    shuffle=False
)

In [None]:
# === CNN model construction ===
# Three Conv -> BatchNorm -> MaxPool stages with 32/64/128 filters, followed by
# a dense classifier head with dropout.
model = Sequential([
    # Input shape is derived from IMG_SIZE so the model cannot silently drift
    # out of sync with the generators' target_size. The 3 channels match the
    # generators' colour mode, which is left at the Keras default ('rgb').
    tf.keras.Input(shape=(*IMG_SIZE, 3)),

    Conv2D(32, (3, 3), activation='relu'),
    BatchNormalization(),
    MaxPooling2D(2, 2),

    Conv2D(64, (3, 3), activation='relu'),
    BatchNormalization(),
    MaxPooling2D(2, 2),

    Conv2D(128, (3, 3), activation='relu'),
    BatchNormalization(),
    MaxPooling2D(2, 2),

    Flatten(),
    Dense(128, activation='relu'),
    Dropout(0.5),
    # One output unit per class known to the training generator
    Dense(len(train_generator.class_indices), activation='softmax')
])

model.summary()

In [None]:
# === Compilation ===
# Categorical cross-entropy pairs with the one-hot labels produced by the
# generators (class_mode='categorical') and the softmax output layer.
model.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)


In [None]:
# === Training with EarlyStopping ===
# Stop once the monitored quantity (Keras default) has not improved for 5
# epochs, then roll the model back to the best weights seen.
early = EarlyStopping(restore_best_weights=True, patience=5)

fit_options = dict(
    validation_data=valid_generator,
    epochs=50,  # upper bound; early stopping usually ends training sooner
    callbacks=[early],
)
history = model.fit(train_generator, **fit_options)


In [None]:
# === Evaluation ===
# evaluate() returns the loss followed by the compiled metrics (accuracy).
val_loss, val_accuracy = model.evaluate(valid_generator)
print(
    f"Loss en validación: {val_loss:.4f}",
    f"Accuracy en validación: {val_accuracy:.4f}",
    sep="\n",
)


In [None]:
# === Metrics ===
# Predicted class = index of the highest softmax output for each image.
y_pred = model.predict(valid_generator)
y_pred_classes = np.argmax(y_pred, axis=1)
y_true = valid_generator.classes

# Only report the classes that actually occur in the labels or predictions.
labels_presentes = unique_labels(y_true, y_pred_classes)

# Map index -> name explicitly rather than relying on the iteration order of
# class_indices happening to match the sorted label indices.
index_to_name = {index: name for name, index in valid_generator.class_indices.items()}
nombres_presentes = [index_to_name[int(i)] for i in labels_presentes]

# Pass the same label list everywhere so rows, columns and names stay aligned.
print(classification_report(y_true, y_pred_classes,
                            labels=labels_presentes,
                            target_names=nombres_presentes))

conf_mat = confusion_matrix(y_true, y_pred_classes, labels=labels_presentes)
plt.figure(figsize=(12,10))
# Tick labels carry the class names; without them the axes only show matrix
# positions, which do not identify the class when some classes are absent.
sns.heatmap(conf_mat, annot=False, cmap='Blues',
            xticklabels=nombres_presentes, yticklabels=nombres_presentes)
plt.title('Matriz de Confusión')
plt.xlabel('Predicho')
plt.ylabel('Real')
plt.xticks(rotation=90)
plt.yticks(rotation=0)
plt.tight_layout()
plt.show()

In [None]:
# === Training curves ===
# Left panel: accuracy per epoch; right panel: loss per epoch.
# Each panel overlays the training and validation series from `history`.
fig, panels = plt.subplots(1, 2, figsize=(10, 4))
curve_specs = (
    ('accuracy', 'val_accuracy', 'Precisión'),
    ('loss', 'val_loss', 'Pérdida'),
)
for ax, (train_key, valid_key, panel_title) in zip(panels, curve_specs):
    ax.plot(history.history[train_key], label='Entrenamiento')
    ax.plot(history.history[valid_key], label='Validación')
    ax.set_title(panel_title)
    ax.legend()
fig.tight_layout()
plt.show()
