# visualization

In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split

# Render all figure text in DejaVu Sans so that a superscript minus sign
# displays correctly (reason taken from the original note).
plt.rcParams.update({'font.family': 'DejaVu Sans'})

# =======================================================
# Directories
# =======================================================
# Input: ground-truth labels are read from <data_dir>/labels.
data_dir = 'data'
labels_dir = os.path.join(data_dir, 'labels')

# Input: model predictions are read from one specific experiment run,
# pinned here by its timestamped folder name.
experiment_dir = os.path.join('experiments', 'experiment_20250914_220130')
results_dir = os.path.join(experiment_dir, 'results')

# Output: figures are written to <data_dir>/visualizations, which is created
# on first use.
visualizations_dir = os.path.join(data_dir, 'visualizations')
os.makedirs(visualizations_dir, exist_ok=True)

# =======================================================
# Load test labels
# =======================================================
# The test labels are not stored separately: they are recovered by repeating
# a 20 % hold-out split (seed 42) on the full label column.
# NOTE(review): this only lines up with the saved predictions if the run that
# produced them split the same samples, in the same order, with the same
# settings — confirm against the training code.
try:
    labels_df = pd.read_csv(os.path.join(labels_dir, 'labels.csv'))
    y = labels_df['label'].values
    # train_test_split returns [train, test]; only the test part is needed.
    y_test = train_test_split(y, test_size=0.2, random_state=42)[1]
    print("Test labels loaded successfully!")
except Exception as e:
    # Surface the reason, then let the failure propagate.
    print("Error loading labels:", str(e))
    raise

# Express the labels as methanol percentage (100 - ethanol%).
# Assumes y_test holds class indices 0-10 standing for 0-100 % ethanol.
y_test_methanol = 100 - 10 * y_test

# =======================================================
# Load predictions
# =======================================================
def _results_file(name):
    """Return the path of *name* inside the experiment's results directory."""
    return os.path.join(results_dir, name)


def _as_methanol_percent(ethanol_class):
    """Map ethanol class indices to methanol percentage (100 - ethanol%)."""
    return 100 - ethanol_class * 10


# All four model outputs are required; a missing file aborts with a hint
# listing the expected names.
try:
    y_pred_1d_densenet = np.load(_results_file('y_pred_1d_densenet.npy'))
    y_pred_2d_densenet = np.load(_results_file('y_pred_2d_densenet.npy'))
    y_pred_1d_resnet = np.load(_results_file('y_pred_1d_resnet.npy'))
    y_pred_2d_resnet = np.load(_results_file('y_pred_2d_resnet.npy'))
    print("All prediction files loaded successfully!")
except FileNotFoundError as e:
    print("Error: One or more prediction files not found:", str(e))
    print("Please ensure all files (y_pred_1d_densenet.npy, y_pred_2d_densenet.npy, "
          "y_pred_1d_resnet.npy, y_pred_2d_resnet.npy) exist in", results_dir)
    raise

# Convert predictions to labels: take the index of the largest value along
# axis 1 of each prediction array.
y_pred_1d_densenet_labels = y_pred_1d_densenet.argmax(axis=1)
y_pred_2d_densenet_labels = y_pred_2d_densenet.argmax(axis=1)
y_pred_1d_resnet_labels = y_pred_1d_resnet.argmax(axis=1)
y_pred_2d_resnet_labels = y_pred_2d_resnet.argmax(axis=1)

# Convert predicted ethanol labels to methanol (100 - ethanol%)
y_pred_1d_densenet_labels_methanol = _as_methanol_percent(y_pred_1d_densenet_labels)
y_pred_2d_densenet_labels_methanol = _as_methanol_percent(y_pred_2d_densenet_labels)
y_pred_1d_resnet_labels_methanol = _as_methanol_percent(y_pred_1d_resnet_labels)
y_pred_2d_resnet_labels_methanol = _as_methanol_percent(y_pred_2d_resnet_labels)

# =======================================================
# Normalized confusion matrix function
# =======================================================
def calculate_normalized_confusion_matrix(y_true, y_pred, labels=None):
    """Return the confusion matrix with each row scaled to percentages.

    Every row (one true class) is divided by its own total and multiplied by
    100, so a row with at least one sample sums to 100. A row with no true
    samples is returned as all zeros.

    Args:
        y_true: True class value of each sample.
        y_pred: Predicted class value of each sample.
        labels: Optional explicit class list, forwarded to scikit-learn's
            ``confusion_matrix`` to fix the row/column order and size. With
            the default ``None`` the classes found in ``y_true`` and
            ``y_pred`` are used in sorted (ascending) order.

    Returns:
        A square float array; entry ``[i, j]`` is the percentage of samples
        of true class ``i`` that were predicted as class ``j``.
    """
    cm = confusion_matrix(y_true, y_pred, labels=labels).astype(float)
    row_totals = cm.sum(axis=1, keepdims=True)
    # Divide only where the row has samples. Rows whose total is zero keep the
    # zeros from `out`, which avoids the 0/0 RuntimeWarning and the NaN
    # clean-up the plain division needed.
    cm_normalized = np.divide(
        cm, row_totals, out=np.zeros_like(cm), where=row_totals != 0
    ) * 100
    return cm_normalized

def plot_normalized_confusion_matrix(y_true, y_pred, title, filename):
    """Save a heatmap of the row-normalized confusion matrix as a PNG.

    Both axes list the classes from the highest to the lowest methanol
    percentage. The tick labels are derived from the classes actually present
    in the data, so they always match the matrix rows and columns.

    Args:
        y_true: True methanol percentage of each sample (numeric).
        y_pred: Predicted methanol percentage of each sample (numeric).
        title: Figure title.
        filename: Output file name; the file is written inside
            ``visualizations_dir``.
    """
    cm_normalized = calculate_normalized_confusion_matrix(y_true, y_pred)

    # The matrix rows/columns follow the sorted (ascending) union of the
    # classes in y_true and y_pred. Build the labels from that same ordering,
    # then flip matrix and labels together so the axes run from the highest
    # class down. A hard-coded descending label list on the unflipped matrix
    # would label every row and column with the wrong class.
    class_values = np.union1d(y_true, y_pred)
    cm_descending = cm_normalized[::-1, ::-1]
    labels = [f"{float(value):g}%" for value in class_values[::-1]]

    plt.figure(figsize=(12, 10))
    try:
        sns.heatmap(cm_descending, annot=True, fmt='.1f', cmap='Blues',
                    xticklabels=labels, yticklabels=labels, cbar_kws={'label': '%'},
                    annot_kws={'size': 14},  # Increase annotation font size
                    square=True)  # Ensure square cells
        plt.xlabel('Predicted Methanol Ratio', fontsize=18, weight='bold')
        plt.ylabel('True Methanol Ratio', fontsize=18, weight='bold')
        plt.title(title, fontsize=20, weight='bold', pad=20)
        plt.xticks(fontsize=14, rotation=45, ha='right')
        plt.yticks(fontsize=14, rotation=0)
        # The colorbar is the last axes added to the figure by the heatmap.
        cbar = plt.gcf().axes[-1]
        cbar.set_ylabel('%', fontsize=16, weight='bold')
        cbar.tick_params(labelsize=14)
        plt.tight_layout()
        plt.savefig(os.path.join(visualizations_dir, filename), dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close()

# =======================================================
# Plot normalized confusion matrices for all models
# =======================================================
# One entry per model: (predicted methanol %, figure title, output file name).
# Order: DenseNet 1D, DenseNet 2D, ResNet 1D, ResNet 2D.
_confusion_plots = (
    (
        y_pred_1d_densenet_labels_methanol,
        "Normalized Confusion Matrix for DenseNet 1D (%)",
        "normalized_confusion_matrix_densenet_1d.png",
    ),
    (
        y_pred_2d_densenet_labels_methanol,
        "Normalized Confusion Matrix for DenseNet 2D (%)",
        "normalized_confusion_matrix_densenet_2d.png",
    ),
    (
        y_pred_1d_resnet_labels_methanol,
        "Normalized Confusion Matrix for ResNet 1D (%)",
        "normalized_confusion_matrix_resnet_1d.png",
    ),
    (
        y_pred_2d_resnet_labels_methanol,
        "Normalized Confusion Matrix for ResNet 2D (%)",
        "normalized_confusion_matrix_resnet_2d.png",
    ),
)

# Every figure is compared against the same true test labels.
for _predicted, _plot_title, _plot_file in _confusion_plots:
    plot_normalized_confusion_matrix(y_test_methanol, _predicted, _plot_title, _plot_file)

# =======================================================
# Print normalized confusion matrix values
# =======================================================
# (display name, predicted methanol % per test sample) for each model.
models = [
    ('DenseNet 1D', y_pred_1d_densenet_labels_methanol),
    ('DenseNet 2D', y_pred_2d_densenet_labels_methanol),
    ('ResNet 1D', y_pred_1d_resnet_labels_methanol),
    ('ResNet 2D', y_pred_2d_resnet_labels_methanol)
]

for model_name, y_pred_labels in models:
    cm_normalized = calculate_normalized_confusion_matrix(y_test_methanol, y_pred_labels)
    # The matrix rows/columns follow the sorted (ascending) union of the
    # classes in the true and predicted labels. Take the class values from
    # that same ordering instead of assuming row i is (100 - i*10) %, which
    # would attach every value to the wrong class.
    class_values = np.union1d(y_test_methanol, y_pred_labels)
    print(f"\nNormalized Confusion Matrix for {model_name} (%):")
    # Walk matrix and class values together, from the highest class down.
    for true_value, row in zip(class_values[::-1], cm_normalized[::-1]):
        print(f"True {float(true_value):g}% Methanol:")
        for pred_value, value in zip(class_values[::-1], row[::-1]):
            print(f"  Predicted {float(pred_value):g}%: {value:.1f}%")