In [None]:
import json
import os
import sys
import matplotlib.pyplot as plt
import numpy as np
from utils.histogram import hsi_to_rgb
import tensorflow as tf

In [None]:
# Training history JSON of the lineRWKV run to inspect; point RUN_DIR at another run to compare.
RUN_DIR = "./results/lineRWKV_2026_01_14-00_33_25"
model_path = f"{RUN_DIR}/logs/training_history.json"

In [None]:
# Load the training history written by the run. It is expected to be a dict mapping series
# names to per-epoch values.
model_history = model_path
if not os.path.exists(model_history):
    raise FileNotFoundError(f"Model history file {model_history} does not exist.")

# Reset first so a failed load below cannot leave a stale `history` from an earlier kernel run.
history = {}
# JSON is UTF-8 by specification; do not depend on the platform's default text encoding.
with open(model_history, "r", encoding="utf-8") as f:
    history = json.load(f)

In [None]:
# Show which series the history file contains.
recorded_series = history.keys()
print(recorded_series)


In [None]:
# One figure per recorded series (epoch_times included), each with its own axes and legend.
for series_name, values in history.items():
    fig, ax = plt.subplots()
    ax.plot(values, label=series_name)
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Value")
    ax.set_title(f"{series_name} over Epochs")
    ax.legend()
    plt.show()

In [None]:
# Group the per-epoch series by metric: data_sorted[metric] = {"train": [...], "val": [...]}.
# Keys starting with "val" are validation series; every other key is treated as training.
data_sorted = {}
for key, series in history.items():
    if key == "epoch_times":
        continue
    parts = key.split("_")
    # The metric name is the second "_"-separated token ("val_psnr" -> "psnr"). A key with no
    # "_" (e.g. a bare "loss") is its own metric name instead of raising IndexError.
    # NOTE(review): multi-word names such as "val_mean_absolute_error" collapse to "mean".
    metric = parts[1] if len(parts) > 1 else parts[0]
    split_name = "val" if key.startswith("val") else "train"
    data_sorted.setdefault(metric, {})[split_name] = series

# Metric name -> "<Display name> _<ordering index>".
# The text before " _" is the human-readable label; the trailing integer is parsed as a sort key.
translation_dict = dict(
    loss="Loss _0",
    acc="Accuracy _1",
    sa="SA _5",
    psnr="PSNR _6",
    ssim="SSIM _4",
    mse="MSE _1",
    mae="MAE _2",
    rmse="RMSE _3",
)

def sorting_key_func(key):
    if key in translation_dict:
        if key == "loss":
            return 0
        else:
            return int(translation_dict[key.split].split("_")[1])

    return len(translation_dict) + 1


In [None]:
# Format data to paste into a LaTeX tikzpicture (pgfplots) figure.
# Every metric present in both `data_sorted` and `translation_dict` becomes one subfigure
# holding its validation and training curves.


def _tikz_addplot(series, color, mark, legend):
    """Return a pgfplots addplot + addlegendentry snippet plotting one per-epoch series."""
    coordinates = "\n ".join(f"({i}, {v})" for i, v in enumerate(series))
    return f"""
        \\addplot [color={color}, mark={mark}] coordinates {{
            {coordinates}
        }};
        \\addlegendentry{{{legend}}}
    """


template = """
\\begin{figure}
\\centering
"""

for key in sorted(data_sorted, key=str.lower):
    if key not in translation_dict:
        continue
    # translation_dict values look like "PSNR _6". Only the part before " _" is the display name.
    # The trailing index is a sort key and must not reach LaTeX, where a bare "_" outside
    # math mode is a compile error.
    display_name = translation_dict[key].rsplit(" _", 1)[0]
    subfig = f"""
    \\begin{{subfigure}}{{0.4\\textwidth}}
        \\centering
        \\begin{{tikzpicture}}
            \\begin{{axis}}[
                    xlabel={{Training Epoch}},
                    ylabel={{Value}},
                    title={{{display_name} over Epochs}},
                    grid=major,
                    y tick label style={{
                            /pgf/number format/.cd,
                            fixed,
                            fixed zerofill,
                            precision=2,
                            /tikz/.cd
                        }},
                    x tick label style={{
                            /pgf/number format/.cd,
                            fixed,
                            fixed zerofill,
                            precision=0,
                            /tikz/.cd
                        }}
                ]
    """
    series_by_split = data_sorted[key]
    # A metric may have been logged for only one split; omit the missing curve instead of raising KeyError.
    if "val" in series_by_split:
        subfig += _tikz_addplot(series_by_split["val"], "blue", "o", "Validation")
    if "train" in series_by_split:
        subfig += _tikz_addplot(series_by_split["train"], "orange", "x", "Training")
    subfig += f"""
        \\end{{axis}}
        \\end{{tikzpicture}}
        \\caption{{CHANGE ME!!!}}
        \\label{{fig:{key}}}
    \\end{{subfigure}}
    \\hfill
    """
    template += subfig
template += """
\\caption{Training and Validation Metrics over Epochs}
\\label{fig:training_metrics}
\\end{figure}
"""

print(template)

In [None]:
# Load small_seg model for visualization
from segmentation.small_seg import small_segmenter

# Must match the configuration used at training time so the saved weights fit the model.
# Presumably (height, width, bands, channels) — TODO confirm against small_segmenter.
SEG_INPUT_SHAPE = (128, 128, 202, 1)
seg_model = small_segmenter(
    input_shape=SEG_INPUT_SHAPE,
    num_classes=4,
    base_filters=8,
    depth=3,
    dropout_rate=0.1,
)

# Build the model with one forward pass on a dummy batch before loading weights.
# The dummy shape is derived from SEG_INPUT_SHAPE so the two cannot drift apart.
dummy_input = tf.random.normal([1, *SEG_INPUT_SHAPE])
_ = seg_model(dummy_input, training=False)

# Load best weights, failing early with a clear message if the file is missing.
weights_path = "./output/models/small_seg_best.weights.h5"
if not os.path.exists(weights_path):
    raise FileNotFoundError(f"Model weights file {weights_path} does not exist.")
seg_model.load_weights(weights_path)
print(f"Model loaded from: {weights_path}")

# Display model summary
seg_model.summary()

In [None]:
# Visualizing segmentation predictions vs ground truth
from TFDataloader.TFdataloader import TFHySpecNetLoader
from matplotlib.colors import ListedColormap
import random

# Class id -> color, shared by the GT and prediction plots (drawn with vmin=0, vmax=3):
# 0 = black (None), 1 = blue (Land), 2 = green (Water), 3 = red (Background)
class_colors = ["black", "blue", "green", "red"]
constant_cmap = ListedColormap(colors=class_colors)

def _save_panel(array, path, **imshow_kwargs):
    """Draw `array` with imshow on a borderless, title-less 6x6 figure and save it to `path`."""
    fig, ax = plt.subplots(figsize=(6, 6))
    ax.imshow(array, **imshow_kwargs)
    ax.axis("off")
    fig.savefig(path, dpi=150, bbox_inches='tight', pad_inches=0)
    plt.close(fig)  # free the figure; many samples are rendered in a loop
    print(f"Saved: {path}")


def visualize_segmentation_comparison(image, gt_mask, pred_mask, save_dir=None, sample_num=None):
    """Save the RGB input, ground-truth mask and predicted mask of one sample as PNGs.

    Nothing is displayed. Images are written without borders or titles to
    ``<save_dir>/sample_<NN>_input.png``, ``..._gt.png`` and ``..._pred.png``.
    Does nothing unless both ``save_dir`` and ``sample_num`` are given, and skips
    a sample whose three files already exist.

    Args:
        image: Hyperspectral image, either a tensor exposing ``.numpy()`` or an
            array-like. It is converted to RGB with ``hsi_to_rgb``.
        gt_mask: Ground-truth class-id mask, drawn with ``constant_cmap`` (vmin=0, vmax=3).
        pred_mask: Predicted class-id mask, drawn the same way.
        save_dir: Output directory, created if missing.
        sample_num: Integer used in the file names (zero-padded to two digits).
    """
    # `sample_num is None` rather than truthiness, so sample 0 is not silently dropped.
    if not save_dir or sample_num is None:
        return
    os.makedirs(save_dir, exist_ok=True)

    prefix = os.path.join(save_dir, f"sample_{sample_num:02d}")
    input_path = f"{prefix}_input.png"
    gt_path = f"{prefix}_gt.png"
    pred_path = f"{prefix}_pred.png"
    # Only skip when all three outputs exist, so a previously interrupted sample gets completed.
    if all(os.path.exists(p) for p in (input_path, gt_path, pred_path)):
        print(f"Sample {sample_num:02d} already exists. Skipping...")
        return

    image_array = image.numpy() if hasattr(image, "numpy") else np.asarray(image)
    _save_panel(hsi_to_rgb(image_array), input_path)
    _save_panel(gt_mask, gt_path, cmap=constant_cmap, vmin=0, vmax=3)
    _save_panel(pred_mask, pred_path, cmap=constant_cmap, vmin=0, vmax=3)


# Load test data: HySpecNet-11k test split, one sample per batch.
data_loader = TFHySpecNetLoader(
    "./test_data/hyspecnet-11k/",
    split="test",
    batch_size=1,
    data_mode=3,
)
# NOTE(review): the meaning of the 1000 argument (and of data_mode=3) is defined by TFHySpecNetLoader — confirm there.
data_set = data_loader._create_dataset(1000)

# Randomly select samples to visualize
num_to_show = 40
SAMPLE_SEED = 42  # fixed so reruns pick the same samples; set to None for a fresh draw each run
sample_rng = random.Random(SAMPLE_SEED)

# Stream the dataset once and reservoir-sample (Algorithm R) up to `num_to_show` multi-class
# samples. Unlike collecting every qualifying sample into a list first, at most `num_to_show`
# (image, mask) batches of hyperspectral cubes are held in memory at any time.
print("Loading dataset samples...")
random_samples = []
num_multiclass = 0
for images, masks in data_set:
    # Only keep samples whose (first) ground-truth mask contains at least two distinct class ids.
    unique_classes = tf.unique(tf.reshape(masks[0], [-1])).y
    if unique_classes.shape[0] < 2:
        continue
    num_multiclass += 1
    if len(random_samples) < num_to_show:
        random_samples.append((images, masks))
    else:
        # Replace a kept sample with probability num_to_show / num_multiclass.
        slot = sample_rng.randrange(num_multiclass)
        if slot < num_to_show:
            random_samples[slot] = (images, masks)

# The reservoir's initial fill is in dataset order; shuffle so the visiting order is random too.
sample_rng.shuffle(random_samples)

print(f"Found {num_multiclass} multi-class samples")

SAVE_DIR = "./Latex/graf"

for idx, (images, masks) in enumerate(random_samples):
    # Drop the batch axis and any singleton axes, and keep the result a tf tensor so
    # visualize_segmentation_comparison can call .numpy() on it.
    image = tf.convert_to_tensor(images[0].numpy().squeeze())

    # Run model prediction
    pred_probs = seg_model(images, training=False)  # (1, 128, 128, 4)
    pred_mask = tf.argmax(pred_probs[0], axis=-1).numpy()  # (128, 128)
    gt_mask = masks[0].numpy()  # (128, 128)

    # Save input / GT / prediction images
    unique_gt_classes = np.unique(gt_mask)
    print(f"Sample {idx + 1} - Unique GT classes: {unique_gt_classes}")
    visualize_segmentation_comparison(image, gt_mask, pred_mask,
                                      save_dir=SAVE_DIR, sample_num=idx + 1)

# Nothing is displayed; images are saved (or skipped if already present) under SAVE_DIR.
print(f"\nVisualizations complete. Processed {len(random_samples)} random samples.")