# Library imports, setup

In [None]:
# autoreload: edits to imported .py files are picked up without restarting the kernel
%load_ext autoreload
%autoreload 2

In [None]:
from data import load_metadata, visualize_data, make_dataset
from model import build_multitask_model
from score_metrics import get_scores

In [None]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# check tf version
print(tf.__version__)

# Ask TensorFlow to allocate GPU memory on demand for every visible GPU.
gpus = tf.config.list_physical_devices('GPU')
for device in gpus:
    try:
        tf.config.experimental.set_memory_growth(device, True)
    except RuntimeError as err:
        # TensorFlow refuses to change this setting once the GPU runtime has
        # been initialized (e.g. something already created tensors before this
        # cell ran). Report it and carry on instead of aborting the notebook.
        print(f"Could not change memory growth for {device.name}: {err}")
    else:
        print(f"Found GPU {device.name}, and set memory growth to True.")

# Loading data

In [None]:
# load_metadata() comes from data.py and returns two objects: the per-image
# metadata (indexed with .iloc and column names further down) and the species
# metadata. NOTE(review): the exact schema of both is defined in data.py.
image_metadata, species_metadata = load_metadata()
# Number of entries in species_metadata; passed to build_multitask_model as num_species.
NUM_SPECIES = len(species_metadata)

# Visualizing data

In [None]:
# visualize_data() is defined in data.py; uncomment the call below to run it.
# visualize_data(image_metadata)

Loading python images from folder

# Building model

In [None]:
import tensorflow as tf
import keras

In [None]:
from model import build_multitask_model
from score_metrics import get_scores

from sklearn.model_selection import StratifiedKFold
from sklearn.utils.class_weight import compute_class_weight

# Image resolution handed to make_batches() and build_multitask_model().
IMAGE_RESOLUTION=224
# Alternative, larger resolution (currently disabled):
#IMAGE_RESOLUTION=544


# Number of folds for the StratifiedKFold split below.
NUM_FOLDS = 3

from data import make_batches



# Inputs of the stratified split: the image paths and the species label
# (encoded_id) that the folds are stratified on.
X_paths = image_metadata["image_path"].values
y_species = image_metadata["encoded_id"].values

# Shuffled stratified K-fold with a fixed seed, so the fold assignment is the
# same on every run.
skf = StratifiedKFold(
    n_splits=NUM_FOLDS,
    shuffle=True,
    random_state=42,
)

# One get_scores() result per fold; the training loop appends to it.
fold_metrics = []


# Index and macro-F1 of the best fold seen so far (updated by the training loop).
best_fold_idx = None
best_macro_f1 = -np.inf

# Per-fold true/predicted labels for both heads. The training loop appends one
# entry per fold and the aggregation cell concatenates them.
all_y_species_true = []
all_y_species_pred = []
all_y_venom_true = []
all_y_venom_pred = []

In [None]:
# Dry run of the input pipeline: builds the training and validation datasets
# for every fold without training anything. The training cell below rebuilds
# the same datasets itself, so nothing created here is required later.
for fold_idx, (train_idx, val_idx) in enumerate(skf.split(X_paths, y_species), start=1):
    print(f"\n===== FOLD {fold_idx}/{NUM_FOLDS} =====")

    train_info = image_metadata.iloc[train_idx].copy()
    val_info   = image_metadata.iloc[val_idx].copy()

    # --- class weights computed on the training split only ---
    species_classes = np.unique(train_info["encoded_id"])
    species_cw = compute_class_weight(
        class_weight="balanced",
        classes=species_classes,
        y=train_info["encoded_id"],
    )
    species_cw_dict = {int(c): w for c, w in zip(species_classes, species_cw)}

    # Build the vector over the full label space (0 .. NUM_SPECIES-1) instead
    # of over the classes that happen to be in this training fold. A species
    # missing from the fold would otherwise raise KeyError or leave the vector
    # too short. Its weight never reaches the training loss, so a neutral 1.0
    # is used.
    # NOTE(review): assumes encoded_id is a contiguous 0-based index and that
    # make_batches looks weights up by encoded_id - confirm in data.py.
    species_weight_vec = tf.constant(
        [species_cw_dict.get(class_id, 1.0) for class_id in range(NUM_SPECIES)],
        dtype=tf.float32,
    )

    # --- tf.data datasets ---
    train_dataset = make_batches(
        train_info,
        IMAGE_RESOLUTION,
        species_weight_vec=species_weight_vec,
    )
    # No class weights are passed for the validation split.
    val_dataset = make_batches(
        val_info,
        IMAGE_RESOLUTION,
        species_weight_vec=None,
    )

In [None]:
# Start from clean accumulators so that re-running this cell does not append a
# second set of folds to results left over from a previous run.
fold_metrics = []
best_fold_idx = None
best_macro_f1 = -np.inf
all_y_species_true = []
all_y_species_pred = []
all_y_venom_true = []
all_y_venom_pred = []

# Hyperparameters shared by all folds.
lr = 5e-4             # stage 1: deliberately low learning rate for EfficientNetB0
fine_tune_lr = 1e-5   # stage 2: even smaller learning rate once everything is trainable
n_epochs = 20         # maximum number of stage-1 epochs
n_fine_tune_epochs = 40  # maximum number of stage-2 epochs


def _compile_multitask(model, learning_rate, loss_weights):
    """Compile the two-headed model with Adam and the losses/metrics used in both stages."""
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        loss={
            "species": "sparse_categorical_crossentropy",
            "venom": "binary_crossentropy",
        },
        loss_weights=loss_weights,
        metrics={
            "species": "accuracy",
            "venom": "accuracy",
        },
    )


for fold_idx, (train_idx, val_idx) in enumerate(skf.split(X_paths, y_species), start=1):
    print(f"\n===== FOLD {fold_idx}/{NUM_FOLDS} =====")

    train_info = image_metadata.iloc[train_idx].copy()
    val_info = image_metadata.iloc[val_idx].copy()

    # --- class weights computed on the training split only ---
    species_classes = np.unique(train_info["encoded_id"])
    species_cw = compute_class_weight(
        class_weight="balanced",
        classes=species_classes,
        y=train_info["encoded_id"],
    )
    species_cw_dict = {int(c): w for c, w in zip(species_classes, species_cw)}

    # Build the vector over the full label space (0 .. NUM_SPECIES-1) instead
    # of over the classes that happen to be in this training fold. A species
    # missing from the fold would otherwise raise KeyError or leave the vector
    # too short. Its weight never reaches the training loss, so a neutral 1.0
    # is used.
    # NOTE(review): assumes encoded_id is a contiguous 0-based index and that
    # make_batches looks weights up by encoded_id - confirm in data.py.
    species_weight_vec = tf.constant(
        [species_cw_dict.get(class_id, 1.0) for class_id in range(NUM_SPECIES)],
        dtype=tf.float32,
    )

    # --- tf.data datasets ---
    train_dataset = make_batches(
        train_info,
        IMAGE_RESOLUTION,
        species_weight_vec=species_weight_vec,
    )
    # No class weights are passed for the validation split.
    val_dataset = make_batches(
        val_info,
        IMAGE_RESOLUTION,
        species_weight_vec=None,
    )

    # --- model: rebuilt from scratch for every fold ---
    model = build_multitask_model(
        num_species=NUM_SPECIES,
        image_resolution=IMAGE_RESOLUTION,
    )

    # ---------- stage 1: frozen backbone ----------
    # NOTE(review): the backbone is presumably frozen inside
    # build_multitask_model - confirm in model.py.
    _compile_multitask(model, lr, {"species": 1.0, "venom": 0.5})

    # --- callbacks with a fold-specific checkpoint ---
    checkpoint_cb = keras.callbacks.ModelCheckpoint(
        f"best_model_fold{fold_idx}.keras",
        monitor="val_loss",
        save_best_only=True,
        save_weights_only=False,
        verbose=1,
    )
    early_stop_cb = keras.callbacks.EarlyStopping(
        monitor="val_loss",
        patience=6,
        restore_best_weights=True,
        verbose=1,
    )
    reduce_lr_cb = keras.callbacks.ReduceLROnPlateau(
        monitor="val_loss",
        factor=0.3,
        patience=3,
        min_lr=1e-6,
        verbose=1,
    )

    history = model.fit(
        train_dataset,
        validation_data=val_dataset,
        epochs=n_epochs,
        callbacks=[checkpoint_cb, early_stop_cb, reduce_lr_cb],
    )

    # ---------- stage 2: unfreeze the backbone and fine-tune ----------
    model.trainable = True
    # Recompiling is required for the trainable change to take effect. The
    # venom head gets the larger loss weight in this stage.
    _compile_multitask(model, fine_tune_lr, {"species": 0.8, "venom": 1.2})

    ft_checkpoint_cb = keras.callbacks.ModelCheckpoint(
        f"best_model_fold{fold_idx}_finetuned.keras",
        monitor="val_loss",
        save_best_only=True,
        save_weights_only=False,
        verbose=1,
    )
    # Must exist before fit() is called: it used to be created after the call
    # that references it, which raised NameError on the first fold.
    ft_early_stop_cb = keras.callbacks.EarlyStopping(
        monitor="val_loss",
        patience=4,
        restore_best_weights=True,
        verbose=1,
    )

    history_fine = model.fit(
        train_dataset,
        validation_data=val_dataset,
        epochs=n_fine_tune_epochs,
        callbacks=[ft_checkpoint_cb, ft_early_stop_cb],
    )

    # --- custom per-fold metrics via get_scores ---
    metrics_fold = get_scores(
        model,
        image_metadata=image_metadata,
        test_dataset=val_dataset,
        venom_threshold=0.5,
    )
    fold_metrics.append(metrics_fold)

    # Remember which fold reached the highest macro-F1.
    if metrics_fold["macro_f1"] > best_macro_f1:
        best_macro_f1 = metrics_fold["macro_f1"]
        best_fold_idx = fold_idx

    # Keep the raw labels/predictions so they can be pooled across folds later.
    all_y_species_true.append(metrics_fold["y_species_true"])
    all_y_species_pred.append(metrics_fold["y_species_pred"])
    all_y_venom_true.append(metrics_fold["y_venom_true"])
    all_y_venom_pred.append(metrics_fold["y_venom_pred"])


In [None]:
# --- aggregated CV results (mean metrics + pooled predictions) ---

# Fail with a clear message when the training loop has not produced anything.
if best_fold_idx is None or not fold_metrics:
    raise RuntimeError(
        "No cross-validation results available; run the training cell first."
    )

# best_fold_idx was chosen from metrics measured after fine-tuning, so export
# the fine-tuned checkpoint of that fold rather than the stage-1 checkpoint.
# No custom_objects are passed: the models in this notebook are compiled with
# built-in losses only, and SoftF1Loss is not imported anywhere here.
best_model = keras.models.load_model(
    f"best_model_fold{best_fold_idx}_finetuned.keras",
)
best_model.save("final_cv_model.keras")

# Scalar metrics are averaged over the folds. Labels and predictions are
# concatenated so the plots below cover every validation sample once.
results_own_metrics = {
    "species_accuracy": np.mean([m["species_accuracy"] for m in fold_metrics]),
    "macro_f1": np.mean([m["macro_f1"] for m in fold_metrics]),
    "venom_accuracy": np.mean([m["venom_accuracy"] for m in fold_metrics]),
    "venom_weighted_species_accuracy": np.mean(
        [m["venom_weighted_species_accuracy"] for m in fold_metrics]
    ),
    "y_species_true": np.concatenate(all_y_species_true, axis=0),
    "y_species_pred": np.concatenate(all_y_species_pred, axis=0),
    "y_venom_true": np.concatenate(all_y_venom_true, axis=0),
    "y_venom_pred": np.concatenate(all_y_venom_pred, axis=0),
}

print("\n=== Átlagolt keresztvalidációs eredmények ===")
print(f"Species accuracy (val): {results_own_metrics['species_accuracy']:.4f}")
print(f"Macro-F1 (species, val): {results_own_metrics['macro_f1']:.4f}")
print(f"Venom accuracy (val): {results_own_metrics['venom_accuracy']:.4f}")
print(
    "Venom-weighted species accuracy (val): "
    f"{results_own_metrics['venom_weighted_species_accuracy']:.4f}"
)

# Example results

In [None]:
# NOTE(review): example_results_from_dataset is not imported anywhere in this
# notebook, so this call raises NameError on a fresh kernel - import it from
# the module that defines it.
# `model` and `val_dataset` are the objects left behind by the last fold of the
# training loop.
example_results_from_dataset(model, val_dataset, species_metadata, n_examples=5)

# Calculating scoring metrics

Function to tell if the species is venomous or not, based on encoded_id

# Plotting mistakes

In [None]:
from sklearn.metrics import classification_report
import matplotlib.pyplot as plt
import numpy as np

def plot_per_class_recall_f1(results):
    """Show one bar chart of per-class recall and one of per-class F1-score.

    Args:
        results: mapping whose "y_species_true" and "y_species_pred" entries
            hold the true and the predicted species labels.
    """
    per_class = classification_report(
        results["y_species_true"],
        results["y_species_pred"],
        output_dict=True,
    )

    # Per-class rows are the ones with purely numeric keys; the remaining keys
    # of the report are skipped.
    class_ids = sorted(int(key) for key in per_class if key.isdigit())

    # (report key, figure title, y-axis label) for each chart, in display order.
    chart_specs = (
        ("recall", "Per-Class Recall", "Recall"),
        ("f1-score", "Per-Class F1-Score", "F1"),
    )
    charts = [
        (title, y_label, [per_class[str(class_id)][metric] for class_id in class_ids])
        for metric, title, y_label in chart_specs
    ]

    for title, y_label, values in charts:
        plt.figure(figsize=(20,6))
        plt.bar(class_ids, values)
        plt.title(title)
        plt.xlabel("Class ID")
        plt.ylabel(y_label)
        plt.show()
import numpy as np

def get_top_confused_pairs(cm, species_names, top_k=20):
    """
    List the most frequent misclassifications found in a confusion matrix.

    Args:
        cm: 2-D NumPy confusion matrix, indexed as cm[true_class, predicted_class].
            It is copied, so the caller's array is left untouched.
        species_names: lookup from class id to species name
        top_k: maximum number of pairs to return

    Returns:
        A list of dicts (keys: true_id, pred_id, true_name, pred_name, count),
        ordered by count from highest to lowest. Pairs with equal counts keep
        row-major matrix order.
    """
    off_diagonal = cm.copy()
    np.fill_diagonal(off_diagonal, 0)  # correct predictions are not confusions

    n_true, n_pred = off_diagonal.shape[0], off_diagonal.shape[1]

    # Collect every off-diagonal cell with at least one hit, walking the matrix
    # row by row.
    mistakes = [
        (true_cls, pred_cls, off_diagonal[true_cls, pred_cls])
        for true_cls, pred_cls in np.ndindex(n_true, n_pred)
        if off_diagonal[true_cls, pred_cls] > 0
    ]

    # Stable sort: most frequent confusion first.
    ranked = sorted(mistakes, key=lambda entry: entry[2], reverse=True)

    return [
        {
            "true_id": true_id,
            "pred_id": pred_id,
            "true_name": species_names[true_id],
            "pred_name": species_names[pred_id],
            "count": count,
        }
        for true_id, pred_id, count in ranked[:top_k]
    ]


In [None]:
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt

def plot_confusion_matrix(results):
    """Draw the species confusion matrix as a heatmap.

    The matrix is computed with normalize='true'.

    Args:
        results: mapping whose "y_species_true" and "y_species_pred" entries
            hold the true and the predicted species labels.
    """
    normalized_cm = confusion_matrix(
        results["y_species_true"],
        results["y_species_pred"],
        normalize='true',
    )

    plt.figure(figsize=(10, 10))
    sns.heatmap(normalized_cm, cmap='Blues')
    plt.title("Normalized Confusion Matrix")
    plt.show()


In [None]:
# Per-class recall and F1 charts over the validation predictions pooled from all folds.
plot_per_class_recall_f1(results_own_metrics)

In [None]:
# Confusion-matrix heatmap over the validation predictions pooled from all folds.
plot_confusion_matrix(results_own_metrics)