# Setup

In [None]:
import os
# These must be set before TensorFlow is imported to take effect.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # suppress TF C++ INFO/WARNING log lines
os.environ['CUDA_VISIBLE_DEVICES'] = "0"  # expose only GPU 0 to TensorFlow

import random
import datetime
import importlib
from typing import List

import numpy as np
import pandas as pd
import seaborn as sns
import tensorflow as tf
import matplotlib.pyplot as plt

# Matrixlib
from matrixkit import preconditioning as prec

# Modellib
import modellib.cnn
import modellib.train
# NOTE(review): the alias `eval` shadows Python's builtin eval(), and `io` shadows
# the standard-library module name. Later cells rely on both aliases, so they are
# kept; consider renaming them (e.g. `mevaluate`, `mio`) across the notebook.
import modellib.evaluate as eval

import modellib.io as io

# Seed every RNG this notebook draws from (Python, NumPy, TensorFlow) so weight
# initialisation and the box-plot jitter are repeatable across Restart & Run All.
# GPU kernels may still be non-deterministic even with fixed seeds.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

# Load and transform matrix datasets

In [None]:
# Read each split from HDF5. Each file yields (bands, labels); the bands are
# presumably a banded representation of the matrices and the labels their
# block-start indicators (TODO confirm against modellib.io.read_from_hdf5).
train_bands, train_labels = io.read_from_hdf5("train_matrices_64_1600.h5")
val_bands, val_labels = io.read_from_hdf5("val_matrices_64_200.h5")
test_bands, test_labels = io.read_from_hdf5("test_matrices_64_200.h5")

# Echo shapes so that a mismatch between inputs and labels is visible immediately.
for split_name, split_bands, split_labels in (
    ("Train", train_bands, train_labels),
    ("Validation", val_bands, val_labels),
    ("Test", test_bands, test_labels),
):
    print(f"{split_name} bands shape: {split_bands.shape}, {split_name} labels shape: {split_labels.shape}")

# Wrap every split as an (unbatched) tf.data pipeline of (bands, labels) pairs.
train_dataset, val_dataset, test_dataset = (
    tf.data.Dataset.from_tensor_slices(pair)
    for pair in (
        (train_bands, train_labels),
        (val_bands, val_labels),
        (test_bands, test_labels),
    )
)

# Training
Run `tensorboard --logdir logs` in a terminal to launch TensorBoard.

In [None]:
# Create a fresh, timestamped log directory for TensorBoard.
log_dir = "logs/cnn/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
# exist_ok makes this a no-op when the directory exists and avoids the
# exists()/makedirs() check-then-act race.
os.makedirs(log_dir, exist_ok=True)
print("Files in log directory:", os.listdir(log_dir))

# Learning-rate schedule: lr = 1e-5 * 0.9 ** (step / 1000). Decay is continuous
# because staircase is left at its default (False).
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=0.00001,
    decay_steps=1000,
    decay_rate=0.9
)

# Define parameters
batch_size = 16
num_epochs = 200
# NOTE(review): presumably (num bands, matrix size, channels); confirm against train_bands.shape.
input_shape = (21, 64, 1)
optimizer = tf.keras.optimizers.Nadam(learning_rate=lr_schedule)
# Per-class loss weights for labels 0/1, passed to the custom-loss model and to evaluation.
class_weights = {0: 0.2, 1: 0.8}

# Create and Compile Model
model = modellib.cnn.create_compile_model_custom_loss(
    input_shape,
    optimizer,
    class_weights
)

# Start Training Loop
# NOTE(review): the training set is batched but never shuffled in this notebook.
# Unless train_model shuffles internally, every epoch sees the samples in the same
# order; consider train_dataset.shuffle(len(train_bands)).batch(batch_size).
trained_model, train_losses, val_losses = modellib.train.train_model(
    model,
    train_dataset.batch(batch_size),
    val_dataset.batch(batch_size),
    num_epochs,
    log_dir
)

In [None]:
# Plot the training and validation losses returned by train_model.
modellib.train.plot_losses(train_losses, val_losses)

# Evaluation

In [None]:
# Evaluate the model returned by train_model on the test set, using the same
# batch size and class weights as training. The result is kept in
# `evaluation_results`.
evaluation_results = modellib.evaluate.evaluate_model(
    trained_model,
    test_dataset.batch(batch_size),
    class_weights
)

# Restore Best Model Weights

In [None]:
# Rebuild the Baseline architecture and load the weights stored in "cnn.weights.h5".
# NOTE(review): assumes train_model (or one of its callbacks) saved the best
# weights to this path in the working directory; confirm.
new_model = modellib.cnn.Baseline(input_shape)
new_model.build((None,) + input_shape)  # None represents the batch dimension
new_model.load_weights("cnn.weights.h5")

# Summarise the restored weights. Printing get_weights() itself would dump
# every weight tensor into the notebook output.
restored_weights = new_model.get_weights()
print(
    f"Restored {len(restored_weights)} weight arrays "
    f"({sum(w.size for w in restored_weights):,} parameters)"
)

# Evaluate the restored model
new_results = modellib.evaluate.evaluate_model(
    new_model,
    test_dataset.batch(batch_size),
    class_weights
)

# Prediction

In [None]:
# Run the restored model over the test set. Use the same batch size as
# training/evaluation; in inference mode the batch size should only affect
# throughput, not the outputs.
test_predictions = new_model.predict(test_dataset.batch(batch_size))

# Convert to binary 0/1 indicators.
# NOTE(review): assumes the model outputs per-position block-start probabilities
# in [0, 1]; confirm.
threshold = 0.5
binary_predictions = (test_predictions >= threshold).astype(int)

# Create Preconditioner from Supervariable Blocking
## 1. Find supervariables (variables sharing the same sparsity pattern)
## 2. Extract an indicator of where each block starts
## 3. Feed the block starts into the preconditioner-from-predictions function

In [None]:
import matrixkit.block as blk
# Reload project modules so edits to them are picked up without a kernel restart.
importlib.reload(blk)
importlib.reload(prec)
importlib.reload(eval)

# Block starts derived purely from the sparsity structure (supervariable blocking).
# NOTE(review): 32 presumably caps the block size; confirm against
# matrixkit.block.supervariable_blocking.
# NOTE(review): `test_matrices` is not defined anywhere in this notebook. It must
# hold the dense test matrices matching `test_bands`; TODO load it explicitly.
block_starts = blk.supervariable_blocking(test_matrices, 32)
# The metrics live in `modellib.evaluate` (the module reloaded above as `eval`).
# `modellib.evaluation` is never imported.
metrics_svblocking = modellib.evaluate.calculate_metrics(test_labels, block_starts)


# Create Preconditioner from predicted block starts
### 1. Form Blocks from Predictions
### 2. Invert Blocks (ensure non-singularity)
### 3. Create Block Diagonal Matrix



In [None]:
importlib.reload(prec)

# Matrix variants fed to GMRES: the raw test matrices plus three prepared
# versions produced by prec.prepare_matrix (one call per matrix and method).
# NOTE(review): `test_matrices` is not defined anywhere in this notebook; TODO load it.
A_original = test_matrices
A_flipped, A_shifted, A_minmax = (
    np.array([prec.prepare_matrix(matrix, method=method) for matrix in test_matrices])
    for method in ('flip', 'shift', 'minmax')
)

# Block-Jacobi preconditioners, all built on the ORIGINAL matrices, from three
# different sources of block starts.
# NOTE(review): the solver cell applies these same preconditioners to the
# flipped / shifted / minmax variants as well; confirm that is intended.
print("Preconditioner from True Block Starts")
precs_true = prec.block_jacobi_preconditioner_from_predictions(test_matrices, test_labels)
print("Preconditioner from Predicted Block Starts")
precs_cnn = prec.block_jacobi_preconditioner_from_predictions(test_matrices, binary_predictions)
print("Preconditioner from Supervariable Blocking")
precs_sv = prec.block_jacobi_preconditioner_from_predictions(test_matrices, block_starts)

# All-ones right-hand sides, shape (num_matrices, test_matrices.shape[1]).
b = np.ones(test_matrices.shape[:2])

In [None]:
 # plot condition numbers 
# One line per matrix variant: x = matrix index, y = np.linalg.cond of that matrix.
fig, ax = plt.subplots(figsize=(10, 6))
for variant, variant_label in (
    (A_original, 'Original Matrix'),
    (A_flipped, 'Prepared Matrix (flipped)'),
    (A_shifted, 'Prepared Matrix (shifted)'),
    (A_minmax, 'Prepared Matrix (minmax)'),
):
    ax.plot([np.linalg.cond(matrix) for matrix in variant], label=variant_label)
ax.set_xlabel('Matrix Index')
ax.set_ylabel('Condition Number')
ax.set_title('Condition Numbers of Original and Prepared Matrices')
ax.legend()
plt.show()

# Run GMRES Solver
### - Without Preconditioner
- With original Matrices
- With flipped Matrices
- With scaled Matrices
- With shifted Matrices

### - With Preconditioner created from true block starts
- With original Matrices
- With flipped Matrices
- With scaled Matrices
- With shifted Matrices

### - With Preconditioner created from predicted block starts
- With original Matrices
- With flipped Matrices
- With scaled Matrices
- With shifted Matrices

### - With Preconditioner created from Supervariable Blocking
- With original Matrices
- With flipped Matrices
- With scaled Matrices
- With shifted Matrices
    

In [None]:
importlib.reload(prec)

# Solve every (matrix variant, preconditioner) combination with GMRES, including
# the unpreconditioned baseline (preconditioner None).
input_matrices_list = [
    ('Original Matrices', A_original),
    ('Flipped Matrices', A_flipped),
    ('Scaled Matrices', A_minmax),
    ('Shifted Matrices', A_shifted)
]
preconditioners_list = [
    ('No Preconditioner', None),
    ('Prec from True Block Starts', precs_true),
    ('Prec from Predicted Block Starts', precs_cnn),
    ('Prec from Supervariable Blocking', precs_sv)
]

results = []

for input_name, input_matrices in input_matrices_list:
    for prec_name, preconditioner_source in preconditioners_list:
        print(f"Solving {input_name} with {prec_name}")
        x, info, iters, residuals = prec.solve_with_gmres_monitored(input_matrices, b, preconditioner_source)

        results.append({
            'Matrix Type': input_name,
            'Preconditioner Type': prec_name,
            'Iterations': iters,
            # Keep the solver's convergence status so that runs which did not
            # converge can be separated from genuine iteration counts later.
            # NOTE(review): presumably SciPy-style (0 = converged); confirm.
            'Info': info,
        })


In [None]:
# Pull out the iteration counts for every (matrix type, preconditioner) pair.
def _iterations_for(matrix_type, prec_type):
    """Return the 'Iterations' entries of all results matching both labels.

    The solver cell appends one result per combination, so this is normally a
    one-element list.
    """
    return [
        entry['Iterations']
        for entry in results
        if entry['Matrix Type'] == matrix_type
        and entry['Preconditioner Type'] == prec_type
    ]


iters_no_prec_original = _iterations_for('Original Matrices', 'No Preconditioner')
iters_no_prec_flipped = _iterations_for('Flipped Matrices', 'No Preconditioner')
iters_no_prec_minmax = _iterations_for('Scaled Matrices', 'No Preconditioner')
iters_no_prec_shifted = _iterations_for('Shifted Matrices', 'No Preconditioner')
iters_prec_true_original = _iterations_for('Original Matrices', 'Prec from True Block Starts')
iters_prec_true_flipped = _iterations_for('Flipped Matrices', 'Prec from True Block Starts')
iters_prec_true_minmax = _iterations_for('Scaled Matrices', 'Prec from True Block Starts')
iters_prec_true_shifted = _iterations_for('Shifted Matrices', 'Prec from True Block Starts')
iters_prec_cnn_original = _iterations_for('Original Matrices', 'Prec from Predicted Block Starts')
iters_prec_cnn_flipped = _iterations_for('Flipped Matrices', 'Prec from Predicted Block Starts')
iters_prec_cnn_minmax = _iterations_for('Scaled Matrices', 'Prec from Predicted Block Starts')
iters_prec_cnn_shifted = _iterations_for('Shifted Matrices', 'Prec from Predicted Block Starts')
iters_prec_sv_original = _iterations_for('Original Matrices', 'Prec from Supervariable Blocking')
iters_prec_sv_flipped = _iterations_for('Flipped Matrices', 'Prec from Supervariable Blocking')
iters_prec_sv_minmax = _iterations_for('Scaled Matrices', 'Prec from Supervariable Blocking')
iters_prec_sv_shifted = _iterations_for('Shifted Matrices', 'Prec from Supervariable Blocking')

def prepare_data(data_list):
    """Flatten a collection of iteration counts into a plain Python list.

    Args:
        data_list: A list (or array) whose items are scalars, lists or arrays of
            any shape, e.g. one per-matrix iteration array per solver run. Items
            may have different lengths.

    Returns:
        All values concatenated in row-major order as a Python list, or ``[0]``
        when ``data_list`` is empty. That placeholder is drawn as a real 0 in the
        plots.
    """
    # len() works for numpy arrays too, whereas `not arr` is ambiguous for size > 1.
    if len(data_list) == 0:
        return [0]
    # Ravel each item separately so ragged items (which np.array cannot stack)
    # are handled as well.
    return np.concatenate([np.ravel(item) for item in data_list]).tolist()

# Series for the two box-plot panels:
#   panel 1: no preconditioner vs. preconditioner from true block starts;
#   panel 2: true vs. predicted (CNN) vs. supervariable-blocking preconditioners.
data_lists_1 = [
    iters_no_prec_original, iters_no_prec_flipped, iters_no_prec_minmax, iters_no_prec_shifted,
    iters_prec_true_original, iters_prec_true_flipped, iters_prec_true_minmax, iters_prec_true_shifted,
]

# Panel 2 starts with the same true-block-start series that close panel 1.
data_lists_2 = data_lists_1[4:] + [
    iters_prec_cnn_original, iters_prec_cnn_flipped, iters_prec_cnn_minmax, iters_prec_cnn_shifted,
    iters_prec_sv_original, iters_prec_sv_flipped, iters_prec_sv_minmax, iters_prec_sv_shifted,
]

prepared_data_1 = list(map(prepare_data, data_lists_1))
prepared_data_2 = list(map(prepare_data, data_lists_2))

def _boxplot_with_jitter(ax, series, tick_labels, title):
    """Draw one box per series on `ax`, overlay the raw points with horizontal
    jitter, and return the dict produced by `ax.boxplot`."""
    box = ax.boxplot(series, tick_labels=tick_labels)
    ax.set_title(title)
    ax.set_ylabel('Number of Iterations')
    ax.grid(True, axis='y', linestyle='--', alpha=0.7)
    ax.tick_params(axis='x', labelrotation=90)
    # Boxes sit at x = 1..len(series); scatter each series' points around its box.
    for position, values in enumerate(series, start=1):
        jitter_x = np.random.normal(position, 0.04, size=len(values))
        ax.scatter(jitter_x, values, alpha=0.4, s=5, zorder=2)
    return box


# One figure with two stacked panels.
fig, (ax_top, ax_bottom) = plt.subplots(2, 1, figsize=(10, 10))

# Top panel: no preconditioner vs. preconditioner from true block starts.
bp1 = _boxplot_with_jitter(
    ax_top,
    prepared_data_1,
    [
        'Original\n(no prec)', 'Flipped\n(no prec)', 'Scaled\n(no prec)', 'Shifted\n(no prec)',
        'Original\n(true prec)', 'Flipped\n(true prec)', 'Scaled\n(true prec)', 'Shifted\n(true prec)'
    ],
    'No Preconditioner vs True Preconditioner',
)

# Bottom panel: every preconditioner source side by side.
bp2 = _boxplot_with_jitter(
    ax_bottom,
    prepared_data_2,
    [
        'Original\n(true prec)', 'Flipped\n(true prec)', 'Scaled\n(true prec)', 'Shifted\n(true prec)',
        'Original\n(modellib prec)', 'Flipped\n(modellib prec)', 'Scaled\n(modellib prec)', 'Shifted\n(modellib prec)',
        'Original\n(SV prec)', 'Flipped\n(SV prec)', 'Scaled\n(SV prec)', 'Shifted\n(SV prec)'
    ],
    'Comparison of All Preconditioners',
)

plt.tight_layout()
plt.show()

In [None]:
# Calculate statistics for iteration counts
def iter_stats(iters):
    """Summarise a collection of iteration counts.

    Args:
        iters: Any array-like of counts; nested lists/arrays are reduced over
            all of their elements.

    Returns:
        dict with keys 'min', 'max', 'mean' and 'median'. For empty input every
        value is NaN; np.min/np.max would otherwise raise, because zero-size
        arrays have no identity for those reductions.
    """
    values = np.asarray(iters)
    if values.size == 0:
        return {'min': np.nan, 'max': np.nan, 'mean': np.nan, 'median': np.nan}
    return {
        'min': np.min(values),
        'max': np.max(values),
        'mean': np.mean(values),
        'median': np.median(values)
    }

# Tabulate iteration statistics for every (matrix type, preconditioner) combination.
# A result's 'Iterations' may hold one count per test matrix, so it is
# flattened before summarising.
results_df = pd.DataFrame(results)
stats_df = pd.DataFrame([
    {
        'Matrix Type': row['Matrix Type'],
        'Preconditioner Type': row['Preconditioner Type'],
        **iter_stats(np.ravel(row['Iterations'])),
    }
    for row in results
])
stats_df

In [None]:
# Box plot and histogram of GMRES iteration counts for four key configurations.
# NOTE(review): this cell previously used `iters_no_prec`, `iters_prec` and
# `iters_prec_minmax`, which are defined nowhere in the notebook (NameError on a
# fresh kernel). They are mapped below onto the series built earlier. "Prepared"
# is read as the flipped matrices and "with prec" as the preconditioner from
# predicted (CNN) block starts; confirm this is the intended comparison.
# prepare_data flattens each series into plain per-run counts.
comparison_series = [
    prepare_data(iters_no_prec_original),
    prepare_data(iters_no_prec_flipped),
    prepare_data(iters_prec_cnn_flipped),
    prepare_data(iters_prec_cnn_minmax),
]

# Create box plots (`tick_labels` replaces the deprecated `labels` argument).
plt.figure(figsize=(10, 6))
bp = plt.boxplot(comparison_series,
                 tick_labels=['Original\n(no prec)', 'Prepared\n(no prec)', 'Prepared\n(with prec)', 'Prepared minmax\n(with prec)'])

plt.title('GMRES Iteration Counts Comparison')
plt.ylabel('Number of Iterations')
plt.grid(True, axis='y', linestyle='--', alpha=0.7)

# Add some jitter to the data points (boxes sit at x = 1..4).
for position, values in enumerate(comparison_series, 1):
    jitter_x = np.random.normal(position, 0.04, size=len(values))
    plt.plot(jitter_x, values, 'r.', alpha=0.2)

plt.tight_layout()
plt.show()

# Histogram
plt.figure(figsize=(8, 4))
plt.hist(comparison_series,
         label=['Original (no prec)', 'Prepared (no prec)', 'Prepared flipped (with prec)', 'Prepared minmax (with prec)'],
         bins=20, alpha=0.7)
plt.xlabel('Number of Iterations')
plt.ylabel('Frequency')
plt.title('Distribution of GMRES Iteration Counts')
plt.legend()
plt.grid(True, linestyle='--', alpha=0.7)
plt.tight_layout()
plt.show()

In [None]:
# Function to plot matrices
def plot_matrices(matrices: List[np.ndarray], titles: List[str], colorbar: str = 'coolwarm') -> None:
    """Show matrices side by side as seaborn heatmaps on a transparent background.

    Args:
        matrices: Sequence of 2-D arrays (a list, or a stacked array whose first
            axis indexes the matrices).
        titles: One title per matrix.
        colorbar: Colormap name passed to ``sns.heatmap`` as ``cmap``. Despite
            the parameter name, it selects the colormap; a colour bar is always
            drawn.

    Raises:
        ValueError: If ``matrices`` and ``titles`` differ in length.
    """
    num_matrices = len(matrices)
    if len(titles) != num_matrices:
        raise ValueError(f"Got {num_matrices} matrices but {len(titles)} titles")
    if num_matrices == 0:
        return  # plt.subplots(1, 0) would raise; there is nothing to draw

    # squeeze=False always returns a 2-D array of Axes, so no single-axis special case.
    fig, axes = plt.subplots(
        1, num_matrices, figsize=(5 * num_matrices, 5),
        facecolor='none', edgecolor='none', squeeze=False,
    )
    fig.patch.set_alpha(0.0)  # fully transparent figure background

    for ax, matrix, title in zip(axes[0], matrices, titles):
        sns.heatmap(matrix, cmap=colorbar, ax=ax, cbar=True)
        ax.set_title(title)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.patch.set_alpha(0.0)  # fully transparent axes background

    plt.tight_layout()
    plt.show()
    
# Plot matrices
# Disabled diagnostic: matrix / preconditioner / product heatmaps for the first 10 test matrices.
# NOTE(review): the titles do not match the arrays. 'Prepared Matrix' shows
# A_original, and 'Product' multiplies by precs_cnn while 'Preconditioner' shows
# precs_true. Fix these before re-enabling.
# for i in range(10):
#     plot_matrices(
#         [test_matrices[i], A_original[i], precs_true[i], np.matmul(A_original[i], precs_cnn[i])], 
#         ['Original Matrix', 'Prepared Matrix', 'Preconditioner', 'Product'], 
#         colorbar='coolwarm'
#     )

# Heatmaps of the first five test matrices.
# NOTE(review): `test_matrices` is not defined anywhere in this notebook; it must
# be loaded explicitly for this cell to run on a fresh kernel.
plot_matrices(test_matrices[:5], ['Original Matrix'] * 5, colorbar='rocket')