## Training and Validation Loss Plots

In [1]:
import json
from pathlib import Path
import matplotlib.pyplot as plt

### Loads and parses training history data from JSON files

In [None]:
def load_model_data(model_json_dict):
    """Load training/validation loss histories for several models from JSON.

    Each JSON file is expected to contain an object with a
    ``"training_history"`` object holding non-empty ``"loss"`` and
    ``"val_loss"`` arrays. Files that are missing, unreadable, not valid
    JSON, or not shaped like that are skipped with a printed message, so
    one bad file does not prevent the others from being compared.

    Args:
        model_json_dict: Mapping of model display name to the path (str or
            path-like) of that model's JSON history file.

    Returns:
        dict: ``{model_name: {"loss": [...], "val_loss": [...]}}`` for every
        file that loaded successfully, in the input mapping's order.
    """
    model_data = {}

    for model_name, json_path_str in model_json_dict.items():
        json_path = Path(json_path_str)

        if not json_path.is_file():
            print(f"File {json_path_str} does not exist")
            continue

        # Keep the try body limited to the I/O and parsing that can fail.
        try:
            with json_path.open("r", encoding="utf-8") as file:
                data = json.load(file)
        except (json.JSONDecodeError, UnicodeDecodeError) as e:
            print(f"Error loading {json_path_str}: {e}")
            continue
        except OSError as e:
            print(f"Error reading {json_path_str}: {e}")
            continue

        # Valid JSON may still have the wrong shape (e.g. a top-level list);
        # treat it as "no history" instead of crashing on ``.get``.
        training_history = data.get("training_history", {}) if isinstance(data, dict) else {}
        if not isinstance(training_history, dict):
            training_history = {}
        loss = training_history.get("loss", [])
        val_loss = training_history.get("val_loss", [])

        if isinstance(loss, list) and isinstance(val_loss, list) and loss and val_loss:
            model_data[model_name] = {
                "loss": loss,
                "val_loss": val_loss,
            }
        else:
            print(f"Warning: 'loss' or 'val_loss' missing or empty in {json_path_str}")

    return model_data

### Plots training and validation loss curves

In [None]:
def _plot_metric(model_data, metric_key, title, ylabel, linestyle):
    """Draw and show one figure overlaying ``metric_key`` for every model."""
    colors = plt.cm.tab10.colors

    plt.figure(figsize=(12, 6))
    for i, (model_name, metrics) in enumerate(model_data.items()):
        values = metrics[metric_key]
        # Build the epoch axis per series so runs of different lengths
        # (e.g. early stopping) do not cause an x/y length mismatch.
        epochs = range(1, len(values) + 1)
        plt.plot(
            epochs,
            values,
            label=model_name,
            # Index-based colour keeps each model's colour identical across figures.
            color=colors[i % len(colors)],
            linestyle=linestyle
        )
    plt.title(title, fontsize=16)
    plt.xlabel("Epochs", fontsize=14)
    plt.ylabel(ylabel, fontsize=14)
    plt.legend(fontsize=12)
    plt.grid(True)
    plt.tight_layout()
    plt.show()


def plot_loss_metrics(model_data):
    """Plot training and validation loss curves for several models.

    Shows two figures: one overlaying every model's training loss (solid
    lines) and one overlaying every model's validation loss (dashed lines).
    Each model uses the same colour in both figures.

    Args:
        model_data: Mapping of model name to a dict with ``"loss"`` and
            ``"val_loss"`` sequences, as produced by ``load_model_data``.
            The sequences may differ in length between models.
    """
    if not model_data:
        print("No data to plot")
        return

    _plot_metric(
        model_data,
        metric_key="loss",
        title="Training Loss Comparison",
        ylabel="Loss",
        linestyle='-',
    )
    _plot_metric(
        model_data,
        metric_key="val_loss",
        title="Validation Loss Comparison",
        ylabel="Validation Loss",
        linestyle='--',
    )

### Coordinates the loading of model data and visualizes loss metrics for comparison

In [2]:
def analyze_models(model_json_dict):
    """Load each model's loss history and show the comparison plots.

    Args:
        model_json_dict: Mapping of model display name to the path of its
            JSON training-history file.

    Prints a message and plots nothing when no file could be loaded.
    """
    loaded = load_model_data(model_json_dict)
    if loaded:
        plot_loss_metrics(loaded)
    else:
        print("No valid model data found")

### Example usage

In [None]:
# Display name (used as the legend label) -> path of that model's JSON
# training-history file. Edit or extend this mapping to compare other runs.
model_json_dict = {
    "Model 1": "models/model_1.json",
    "Model 2": "models/model_2.json",
}

In [None]:
# Run the full load-and-plot pipeline on the example mapping above.
analyze_models(model_json_dict=model_json_dict)