In [None]:
%load_ext autoreload
%autoreload 2
#os.environ["OMP_NUM_THREADS"] = str(os.cpu_count() // 2)
#os.environ["GOMP_CPU_AFFINITY"] = "granularity=core,compact"
from Scripts.model import *
from Scripts.loss import *
from Scripts.results_manager import *
from Scripts.plots import *
from Scripts.dataset import *
from Scripts.trainer import *
from Scripts.inference import *
from Scripts.Onnx_Class import *
from Scripts.lr_finder import *
from Scripts.generate_configs import *
from Scripts.excecute import *
from Scripts.upload_summaries import *
from Scripts.quantize import *

In [None]:
# Directory holding the training YAML configs; passed to the training cells below.
config_path: str = "Configs"

In [None]:
# Train only a single MVTec class using the configs in `config_path`.
# The metrics upload stays disabled here (uncomment to push this run's metrics).
selected_class = 'grid'
training_selected_class(config_path, selected_class)
# metrics_to_db()

In [None]:
def _train_all_classes_and_upload(config_dir):
    """Run `training_all_classes` on `config_dir`, then push the metrics via `metrics_to_db`."""
    training_all_classes(config_dir)
    metrics_to_db()


_train_all_classes_and_upload(config_path)

In [None]:
# Run `inference_model` (Scripts.inference) over the runs stored in
# "Training_Runs", using "Inference_Runs" as its output directory.
training_run_folder, inference_output_dir = "Training_Runs", "Inference_Runs"
inference_model(training_run_folder, inference_output_dir)

In [None]:
# Collect, per training run, its YAML config, summary_metrics.json and
# *best_student.pth weight file, and load the config + summary metrics.
#
# The three glob() calls are independent and glob() returns files in an
# arbitrary, filesystem-dependent order, so zipping the raw results can pair
# files from different runs. Each list is therefore sorted: with exactly one
# file of each kind per run directory, sorting by path lines the lists up run
# by run.
import glob
import os

base_path = "timm_training_runs"

config_paths_all = sorted(glob.glob(os.path.join(
    base_path, "**", "*.yaml"), recursive=True))
summary_metrics_paths_all = sorted(glob.glob(os.path.join(
    base_path, "**", "summary_metrics.json"), recursive=True))
best_student_weight_paths_all = sorted(glob.glob(os.path.join(
    base_path, "**", "*best_student.pth"), recursive=True))

# zip() silently truncates to the shortest list; differing counts mean the
# pairing would mix files from different runs, so stop here instead.
if not (len(config_paths_all) == len(summary_metrics_paths_all)
        == len(best_student_weight_paths_all)):
    raise ValueError(
        f"Mismatched run files under {base_path!r}: "
        f"{len(config_paths_all)} YAML configs, "
        f"{len(summary_metrics_paths_all)} summary_metrics.json, "
        f"{len(best_student_weight_paths_all)} *best_student.pth files")

# One (config, summary metrics, best-student weight path) tuple per run,
# kept so later cells can use the loaded data instead of re-reading it.
loaded_runs = []
for configs, summary_metrics, best_student_weights in zip(
        config_paths_all, summary_metrics_paths_all, best_student_weight_paths_all):
    config = load_config(configs)
    summary_metric = load_json(summary_metrics)
    loaded_runs.append((config, summary_metric, best_student_weights))

In [None]:
# Quantized-model inference: for every saved quantized student state dict under
# `quantized_models`, rebuild the FX-quantized STFPM student, load the weights
# and evaluate it on the configured MVTec test split.
#
# Names are deliberately `quantized_*` / `stfpm_model`: this cell must not
# overwrite `inference_model` (the function from Scripts.inference used in the
# cell above) or that cell's `inference_output_dir`. Otherwise re-running the
# earlier cell after this one breaks or writes into the wrong folder.
import glob
import os
from pathlib import Path

import torch

quantized_inference_output_dir = 'quantized_inference_results'
device = torch.device('cpu')  # FX-quantized (fbgemm) models are executed on the CPU


def _find_yaml_config(dir_path):
    """Return the path of the alphabetically first `.yaml` file in `dir_path`, or None."""
    yaml_files = sorted(f for f in os.listdir(dir_path) if f.endswith('.yaml'))
    return os.path.join(dir_path, yaml_files[0]) if yaml_files else None


def _build_quantized_stfpm(config, weight_path, device):
    """Build an STFPM model whose student is FX-quantized and loaded from `weight_path`.

    Args:
        config: Parsed run config (uses the `model` and `dataset.img_size` entries).
        weight_path: Path to the saved state dict of the quantized student.
        device: Torch device the model is moved to.

    Returns:
        The STFPM model in eval mode with `student_model` replaced by the
        quantized student.
    """
    stfpm_model = STFPM(
        architecture=config['model']['architecture'],
        layers=config['model']['layers'],
        quantize=False
    ).to(device).eval()
    student_model = stfpm_model.student_model.to(device).eval()
    img_size = config['dataset']['img_size']

    # The student is traced with the stem's output as its input, so the example
    # input is a stem feature map for a random image, not the raw image itself.
    # no_grad: only the tensor values/shape are needed, not an autograd graph.
    with torch.no_grad():
        example_inputs = (stfpm_model.stem_model(
            torch.randn(1, 3, img_size, img_size, device=device)),)

    print("Bereite das Studenten-Modell fuer die Quantisierung vor...")
    prepared_model = quantize_fx.prepare_fx(
        student_model, get_default_qconfig_mapping('fbgemm'), example_inputs
    )
    print("Konvertiere das vorbereitete Modell in ein quantisiertes Modell...")
    quantized_student = quantize_fx.convert_fx(prepared_model)

    # No calibration is run here: the saved state dict supplies the trained
    # quantized weights and quantization parameters.
    print("Lade die quantisierten Gewichte in das Modell...")
    quantized_student.load_state_dict(
        torch.load(weight_path, map_location=device)
    )
    stfpm_model.student_model = quantized_student.eval().to(device)
    return stfpm_model


def _build_test_loader(config):
    """Load the configured MVTec test split fully into memory and wrap it in a DataLoader."""
    test_set = MVTecDataset(
        img_size=config['dataset']['img_size'],
        base_path=config['dataset']['base_path'],
        cls=config['dataset']['class'],
        mode='test',
        download_if_missing=False
    )
    print("Lade den gesamten Test-Datensatz in den Arbeitsspeicher...")
    memory_cache = [test_set[i] for i in tqdm(range(len(test_set)))]
    return DataLoader(
        memory_cache,
        batch_size=config['dataloader']['batch_size'],
        shuffle=False
    )


# Sorted so runs are processed in a reproducible order (glob order is arbitrary).
quantized_weights_paths = sorted(glob.glob(os.path.join(
    'quantized_models', '**', 'quantized_model.pth'), recursive=True))

print(
    f"Führe Inferenz für {len(quantized_weights_paths)} quantisierte Modelle aus...")

for weight_path in quantized_weights_paths:
    dir_path = os.path.dirname(weight_path)
    model_name = Path(weight_path).parent.parent.name
    print(f"Modell aus {model_name} wird verwendet...")

    yaml_path = _find_yaml_config(dir_path)
    if yaml_path is None:
        print(
            f"Keine YAML-Konfigurationsdatei im Verzeichnis {dir_path} gefunden.")
        continue

    json_path = os.path.join(dir_path, 'summary_metric.json')
    if not os.path.exists(json_path):
        print(
            f"Keine JSON-Zusammenfassungsdatei im Verzeichnis {dir_path} gefunden.")
        continue

    config = load_config(yaml_path)
    summary_data = load_json(json_path)
    training_id = summary_data.get('training_id', 'quantized_run')

    print("Lade Modelle fuer die Quantisierung...")
    stfpm_model = _build_quantized_stfpm(config, weight_path, device)

    try:
        test_loader = _build_test_loader(config)
    except Exception as e:
        # Best effort: a missing or broken dataset only skips this model.
        print(f"Fehler beim Laden des Test-Datensatzes für {yaml_path}: {e}")
        continue

    infer = Inference(
        model=stfpm_model,
        test_loader=test_loader,
        config=config,
        output_dir=quantized_inference_output_dir,
        path_to_student_weight=None,
        trainings_id=training_id,
        inferenz=True
    )

    print(f"Starte Inferenz für Konfiguration: {yaml_path}...")
    auroc_score, total_inference_time = infer.evaluate_quantized_loaded_model()
    infer.create_inference_summary(
        summary_data, auroc_score, total_inference_time)
    print(
        f"Inferenz abgeschlossen. AUROC: {auroc_score:.4f}, Zeit: {total_inference_time:.4f}s.")

    # infer.generate_heatmaps_from_saved_maps()

print("\n--- Alle Inferenzläufe abgeschlossen. ---")

In [None]:
import glob
import os
import time  # wall-clock timing of each inference call

import matplotlib.pyplot as plt
import numpy as np
import onnxruntime as ort
from PIL import Image

# Standalone ONNX Runtime check: run the exported STFPM model on the
# "broken_large" bottle test images and show each anomaly map over its image.
# Paths are built with os.path.join (not Windows backslash literals) so the
# cell also runs on Linux/macOS.
onnx_model_path = os.path.join(
    "onnx_models", "STFPM_bottle_mobilenetv4_conv_large.onnx")
image_dir = os.path.join("Images", "bottle", "test", "broken_large")
img_size = 256  # side length the images are resized to before inference

# 1. Load the ONNX model (the input name does not change per image).
sess = ort.InferenceSession(onnx_model_path)
input_name = sess.get_inputs()[0].name

# Untimed warm-up run, so the first image's timing does not include first-call
# overhead. Same layout/dtype as the real inputs: (1, img_size, img_size, 3) uint8.
sess.run(None, {input_name: np.zeros((1, img_size, img_size, 3), dtype=np.uint8)})

# Sorted so images are processed in a reproducible order (glob order is arbitrary).
image_paths = sorted(glob.glob(os.path.join(image_dir, '*.png')))
for path in image_paths:
    with Image.open(path) as raw_image:
        image = raw_image.convert("RGB").resize((img_size, img_size))

    # HWC uint8 pixels plus a leading batch axis -> (1, img_size, img_size, 3).
    input_data = np.expand_dims(np.array(image, dtype=np.uint8), axis=0)

    start_time = time.perf_counter()
    outputs = sess.run(None, {input_name: input_data})
    inference_time = time.perf_counter() - start_time

    # The model's outputs are used as (anomaly map, anomaly score).
    anomaly_map = outputs[0]
    anomaly_score = outputs[1]

    print(f"Anomalie-Score für das Bild: {anomaly_score[0]}")
    print(f"Inferenzzeit: {inference_time:.4f} Sekunden")

    fig, axes = plt.subplots(1, 2, figsize=(12, 6))

    # 2. Show the original image
    axes[0].imshow(image)
    axes[0].set_title("Originalbild")
    axes[0].axis('off')

    # 3. Overlay the heatmap on the original image. The explicit extent stretches
    # the map over the image's pixel grid, so the overlay stays aligned even if
    # the map's resolution differs from img_size (identical when it matches).
    axes[1].imshow(image)
    heatmap = axes[1].imshow(
        np.squeeze(anomaly_map), cmap='jet', alpha=0.5,
        extent=(-0.5, img_size - 0.5, img_size - 0.5, -0.5))
    axes[1].set_title("Anomalie-Heatmap")
    axes[1].axis('off')

    # 4. Add a colour bar for the heatmap
    fig.colorbar(heatmap, ax=axes[1], fraction=0.046, pad=0.04)

    plt.tight_layout(rect=[0, 0, 1, 0.95])
    # 5. Show the figure; the file name tells the per-image figures apart.
    fig.suptitle(
        f"Vergleich: Originalbild vs. Anomalie-Heatmap ({os.path.basename(path)})")
    plt.show()
    # Release the figure explicitly so a long image list does not accumulate figures.
    plt.close(fig)