In [None]:
import os, json, time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.layers import (Input, Conv2D, DepthwiseConv2D, BatchNormalization,
                                     Activation, Add, Dropout, GlobalAveragePooling2D,
                                     Dense, Multiply, Reshape)
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
from tensorflow.keras.optimizers import SGD
from sklearn.metrics import classification_report, confusion_matrix, precision_recall_fscore_support, roc_auc_score

import seaborn as sns

def efficient_channel_attention(x, ratio=8):
    """Re-weight the channels of ``x`` with a learned sigmoid gate.

    The feature map is global-average-pooled, squeezed through
    ``Dense(max(1, C // ratio), relu)``, expanded back with
    ``Dense(C, sigmoid)``, reshaped to ``(1, 1, C)`` and multiplied onto ``x``.

    NOTE(review): despite the name, this is the squeeze-and-excitation layout
    (two fully connected layers with a reduction ratio), not a 1-D-conv ECA gate.

    Args:
        x: Keras tensor; presumably channels-last ``(batch, H, W, C)`` given the
            ``(1, 1, C)`` reshape — confirm for new callers.
        ratio: reduction factor for the hidden Dense layer (floored at 1 unit).

    Returns:
        ``x`` scaled channel-wise, or ``x`` itself when its channel dimension is
        not statically known.
    """
    channels = K.int_shape(x)[-1]
    if channels is not None:
        hidden_units = max(1, channels // ratio)
        gate = GlobalAveragePooling2D()(x)
        gate = Dense(hidden_units, activation='relu', use_bias=True)(gate)
        gate = Dense(channels, activation='sigmoid', use_bias=True)(gate)
        gate = Reshape((1, 1, channels))(gate)
        return Multiply()([x, gate])
    # Unknown channel count: the gate cannot be sized, so pass through.
    return x

def inverted_residual_block(x_in, filters_out, strides=1, expansion_factor=4, use_attention=True,
                            dropout_rate=0.0, attention_min_filters=64, attention_ratio=8):
    """Build a MobileNetV2-style inverted residual block.

    Layout: optional 1x1 expansion (Conv2D + BN + relu6) -> 3x3 depthwise conv
    (stride ``strides``) + BN + relu6 -> optional channel attention -> linear
    1x1 projection + BN -> optional Dropout -> identity shortcut when shapes
    allow it.

    Args:
        x_in: input Keras tensor; its last axis is read as the channel count.
        filters_out: number of output channels of the projection conv.
        strides: stride of the depthwise conv.
        expansion_factor: channel multiplier for the expansion conv; values
            ``<= 1`` skip the expansion entirely.
        use_attention: whether attention may be applied at all.
        dropout_rate: Dropout rate after the projection BN (0 disables it).
        attention_min_filters: attention is only inserted when
            ``filters_out >= attention_min_filters``. Exposed so that narrow
            (width-scaled) networks can opt in; the default 64 keeps the
            original behaviour.
        attention_ratio: reduction ratio passed to ``efficient_channel_attention``.

    Returns:
        Output Keras tensor with ``filters_out`` channels.
    """
    shortcut = x_in
    filters_in = K.int_shape(x_in)[-1]

    # 1x1 expansion (skipped when expansion_factor <= 1).
    if expansion_factor > 1:
        x = Conv2D(filters_in * expansion_factor, 1, padding='same', use_bias=False,
                   kernel_initializer='he_normal')(x_in)
        x = BatchNormalization()(x)
        x = Activation('relu6')(x)
    else:
        x = x_in

    # 3x3 depthwise conv; this is where the block's spatial stride is applied.
    x = DepthwiseConv2D(3, strides=strides, padding='same', use_bias=False,
                        depthwise_initializer='he_normal')(x)
    x = BatchNormalization()(x)
    x = Activation('relu6')(x)

    # The threshold is compared against filters_out as passed in, i.e. after
    # any width scaling the caller has already applied.
    if use_attention and filters_out >= attention_min_filters:
        x = efficient_channel_attention(x, ratio=attention_ratio)

    # Linear 1x1 projection: BN but no activation before the residual add.
    x = Conv2D(filters_out, 1, padding='same', use_bias=False, kernel_initializer='he_normal')(x)
    x = BatchNormalization()(x)
    if dropout_rate > 0:
        x = Dropout(dropout_rate)(x)

    # Identity shortcut only when spatial size and channel count are unchanged.
    if strides == 1 and filters_in == filters_out:
        x = Add()([shortcut, x])
    return x

def create_MediNet_XG(input_shape, num_classes, width_multiplier=0.5):
    """Assemble the (uncompiled) MediNet_XG classifier.

    Structure: 3x3 stride-2 stem conv -> four stages of two inverted residual
    blocks each -> 1x1 head conv -> global average pooling (layer name
    ``"embedding"``) -> Dropout(0.2) -> softmax Dense over ``num_classes``.
    Every channel count is scaled by ``width_multiplier`` and rounded to a
    multiple of 8.

    NOTE(review): with the default width_multiplier=0.5 the stage widths round
    to 16, 16, 32 and 48 filters. ``inverted_residual_block`` only applies
    attention when ``filters_out >= 64`` (its default threshold), so the
    ``use_attn=True`` flags on the last two stages have no effect at this
    width — confirm that is intended.

    Args:
        input_shape: shape of one input sample, e.g. ``(H, W, 3)``.
        num_classes: size of the softmax output.
        width_multiplier: channel scaling factor applied to every stage.

    Returns:
        A ``tf.keras.Model`` named ``"MediNet_XG"``.
    """
    def make_divisible(v, divisor=8):
        # Round to the nearest multiple of `divisor` (at least `divisor`), then
        # bump up one step if rounding lost more than 10% of the requested value.
        candidate = max(divisor, (int(v + divisor / 2) // divisor) * divisor)
        return candidate + divisor if candidate < 0.9 * v else candidate

    inputs = Input(shape=input_shape)

    # Stem.
    stem_filters = make_divisible(16 * width_multiplier)
    x = Conv2D(stem_filters, 3, strides=2, padding='same', use_bias=False)(inputs)
    x = BatchNormalization()(x)
    x = Activation('relu6')(x)

    # (base filters, expansion factor, stride of first block, attention flag)
    stages = (
        (24, 2, 2, False),
        (32, 3, 2, False),
        (64, 4, 2, True),
        (96, 6, 1, True),
    )
    for base_filters, expansion, first_stride, attn in stages:
        stage_filters = make_divisible(base_filters * width_multiplier)
        # First block may downsample and uses dropout; the second keeps the
        # resolution (and therefore gets the residual shortcut).
        for block_stride, block_dropout in ((first_stride, 0.1), (1, 0.0)):
            x = inverted_residual_block(x, stage_filters, strides=block_stride,
                                        expansion_factor=expansion, use_attention=attn,
                                        dropout_rate=block_dropout)

    # Head.
    head_filters = make_divisible(320 * width_multiplier)
    x = Conv2D(head_filters, 1, use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu6')(x)
    emb = GlobalAveragePooling2D(name="embedding")(x)
    x = Dropout(0.2)(emb)
    outputs = Dense(num_classes, activation='softmax')(x)
    return Model(inputs=inputs, outputs=outputs, name="MediNet_XG")

def get_callbacks(model_name):
    """Return the training callbacks for ``model_name``.

    In order:
      1. checkpoint of the lowest ``val_loss`` model -> ``best_loss_<name>.keras``
      2. checkpoint of the highest ``val_accuracy`` model -> ``best_acc_<name>.keras``
      3. early stopping on ``val_loss`` (patience 15, best weights restored)
      4. LR halving on a ``val_loss`` plateau (patience 7, floor 1e-6)

    NOTE(review): relies on a module-level ``OUTPUT_DIR`` defined outside this view.
    """
    loss_checkpoint = ModelCheckpoint(f"{OUTPUT_DIR}/best_loss_{model_name}.keras",
                                      monitor="val_loss", save_best_only=True,
                                      mode="min", verbose=0)
    accuracy_checkpoint = ModelCheckpoint(f"{OUTPUT_DIR}/best_acc_{model_name}.keras",
                                          monitor="val_accuracy", save_best_only=True,
                                          mode="max", verbose=0)
    early_stop = EarlyStopping(monitor="val_loss", patience=15,
                               restore_best_weights=True, verbose=0)
    lr_on_plateau = ReduceLROnPlateau(monitor="val_loss", factor=0.5, patience=7,
                                      min_lr=1e-6, verbose=0)
    return [loss_checkpoint, accuracy_checkpoint, early_stop, lr_on_plateau]

def evaluate_model(model, dataset, target_names=None):
    """Run ``model`` over ``dataset`` and collect classification metrics.

    Labels and predictions are gathered from the same iteration over
    ``dataset``, so they stay aligned even if the dataset reshuffles between
    passes. Loss and accuracy come from a separate ``model.evaluate`` pass.

    Args:
        model: Keras classifier with one probability column per class.
            ``model.evaluate`` is unpacked as ``(loss, accuracy)``, so the model
            must be compiled with exactly one metric.
        dataset: iterable of ``(x, y)`` batches. ``y`` is argmax-ed over axis 1,
            i.e. assumed one-hot (consistent with the categorical_crossentropy
            loss used in this file).
        target_names: display names for classes ``0..C-1``. Defaults to the
            module-level ``class_names``.

    Returns:
        dict with ``loss``, ``accuracy``, ``precision_macro``, ``recall_macro``,
        ``f1_macro``, ``auc_ovr`` (NaN when AUC is undefined), ``report_txt`` /
        ``report_dict`` (sklearn classification report) and ``cm`` (C x C
        confusion matrix; rows are true classes, columns are predictions).

    Raises:
        ValueError: if ``dataset`` yields no batches.
    """
    if target_names is None:
        target_names = class_names

    true_batches, prob_batches = [], []
    for x, y in dataset:
        true_batches.append(np.asarray(y))
        # model.predict is meant for whole datasets; for one batch at a time
        # predict_on_batch avoids re-running predict's batching setup per call.
        prob_batches.append(np.asarray(model.predict_on_batch(x)))
    if not true_batches:
        raise ValueError("evaluate_model: dataset yielded no batches")

    y_true = np.vstack(true_batches)
    y_prob = np.vstack(prob_batches)
    y_true_idx = np.argmax(y_true, axis=1)
    y_pred_idx = np.argmax(y_prob, axis=1)
    # Pin the label set to every model output. Otherwise a class absent from
    # this split makes classification_report raise (target_names size
    # mismatch) and shrinks the confusion matrix out of line with the names.
    labels = np.arange(y_prob.shape[1])

    loss, acc = model.evaluate(dataset, verbose=0)
    pr, rc, f1, _ = precision_recall_fscore_support(y_true_idx, y_pred_idx,
                                                    average="macro", zero_division=0)

    try:
        auc_ovr = float(roc_auc_score(y_true, y_prob, multi_class="ovr"))
    except ValueError:
        # e.g. a class with no positive samples in y_true: AUC is undefined.
        auc_ovr = float("nan")

    report_txt = classification_report(y_true_idx, y_pred_idx, labels=labels,
                                       target_names=target_names, zero_division=0)
    report_dict = classification_report(y_true_idx, y_pred_idx, labels=labels,
                                        target_names=target_names, zero_division=0,
                                        output_dict=True)
    cm = confusion_matrix(y_true_idx, y_pred_idx, labels=labels)

    return {
        "loss": float(loss),
        "accuracy": float(acc),
        "precision_macro": float(pr),
        "recall_macro": float(rc),
        "f1_macro": float(f1),
        "auc_ovr": auc_ovr,
        "report_txt": report_txt,
        "report_dict": report_dict,
        "cm": cm,
    }

# Build, compile and train MediNet_XG.
# NOTE(review): img_size, num_classes, train_dataset and val_dataset are
# expected to come from an earlier cell; they are not defined in this view.
input_shape = (img_size, img_size, 3)
medinet_xg = create_MediNet_XG(
    input_shape=input_shape,
    num_classes=num_classes,
    width_multiplier=0.5,
)
# categorical_crossentropy expects one-hot encoded targets.
medinet_xg.compile(
    optimizer=SGD(learning_rate=0.01, momentum=0.8, nesterov=True),
    loss="categorical_crossentropy",
    metrics=["accuracy"],
)

# At most 300 epochs; the EarlyStopping callback from get_callbacks may stop
# training earlier. verbose=0 suppresses per-epoch output.
history_medinet = medinet_xg.fit(
    x=train_dataset,
    validation_data=val_dataset,
    epochs=300,
    callbacks=get_callbacks("MediNet_XG"),
    verbose=0,
)

# Persist the trained model in HDF5 format.
medinet_xg.save(f"{OUTPUT_DIR}/MediNet_XG_final.h5")

# Training curves: accuracy (left panel) and loss (right panel), train vs. val.
plt.figure(figsize=(10, 4))

plt.subplot(1, 2, 1)
plt.plot(history_medinet.history["accuracy"])
plt.plot(history_medinet.history["val_accuracy"])
plt.title("MediNet_XG Accuracy")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.legend(["train", "val"])

plt.subplot(1, 2, 2)
plt.plot(history_medinet.history["loss"])
plt.plot(history_medinet.history["val_loss"])
plt.title("MediNet_XG Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend(["train", "val"])

plt.tight_layout()
plt.savefig(f"{OUTPUT_DIR}/curves_MediNet_XG.png", dpi=300)
plt.show()
plt.close()

# Evaluate on the held-out test split and write the reports to OUTPUT_DIR.
eval_medinet = evaluate_model(medinet_xg, test_dataset)

# Explicit UTF-8: class names in the text report may be non-ASCII, and the
# platform default encoding is not guaranteed to handle them.
with open(f"{OUTPUT_DIR}/classification_report_MediNet_XG.txt", "w", encoding="utf-8") as f:
    f.write(eval_medinet["report_txt"])
with open(f"{OUTPUT_DIR}/classification_report_MediNet_XG.json", "w", encoding="utf-8") as f:
    json.dump(eval_medinet["report_dict"], f, indent=2)
# JSON has no NaN/Infinity literal; by default json.dump would emit a bare
# `NaN` (e.g. for an undefined AUC) that strict parsers reject. Write non-finite
# values as null and let allow_nan=False enforce strict output.
with open(f"{OUTPUT_DIR}/metrics_MediNet_XG.json", "w", encoding="utf-8") as f:
    json.dump(
        {k: (eval_medinet[k] if np.isfinite(eval_medinet[k]) else None)
         for k in ["loss", "accuracy", "precision_macro", "recall_macro", "f1_macro", "auc_ovr"]},
        f,
        indent=2,
        allow_nan=False,
    )

# Confusion-matrix heatmap. sklearn's confusion_matrix puts true classes on
# the rows and predicted classes on the columns; label the axes accordingly
# so the figure is unambiguous.
cm = eval_medinet["cm"]
plt.figure(figsize=(12, 10))
sns.heatmap(cm, cmap="Blues", xticklabels=class_names, yticklabels=class_names, cbar=True)
plt.xlabel("Predicted label")
plt.ylabel("True label")
plt.xticks(rotation=45, ha="right", fontsize=8)
plt.yticks(rotation=0, fontsize=8)
plt.tight_layout()
plt.savefig(f"{OUTPUT_DIR}/confusion_matrix_MediNet_XG.png", dpi=300, bbox_inches="tight")
plt.show()
plt.close()

# Model footprint: parameter count and size of the saved HDF5 file.
eff_path = f"{OUTPUT_DIR}/MediNet_XG_final.h5"
params = int(medinet_xg.count_params())
size_kb = os.path.getsize(eff_path) / 1024

# Inference latency over (up to) 5 test batches.
# Pull the batches out of the input pipeline *before* starting the clock, so
# that decoding/preprocessing time is not reported as model latency.
timing_batches = [x for x, _ in test_dataset.take(5)]
n = sum(int(x.shape[0]) for x in timing_batches)
if timing_batches:
    # Untimed warm-up call so one-time setup cost is excluded from the measurement.
    medinet_xg.predict(timing_batches[0], verbose=0)
# perf_counter is monotonic and high-resolution, unlike time.time().
start = time.perf_counter()
for x in timing_batches:
    _ = medinet_xg.predict(x, verbose=0)
ms_per_img = ((time.perf_counter() - start) / max(1, n)) * 1000

with open(f"{OUTPUT_DIR}/efficiency_MediNet_XG.json", "w") as f:
    json.dump({"params": params, "size_kb": float(size_kb), "ms_per_image": float(ms_per_img)}, f, indent=2)
