# Data-Driven Format Selection
---


### Ensure CUDA is enabled

In [None]:
import torch

# Report the installed PyTorch build and whether it can reach a CUDA device.
print(f"PyTorch version: {torch.__version__}")
cuda_ok = torch.cuda.is_available()
print(f"CUDA available: {cuda_ok}")

# Device details are only queried when CUDA is usable; only device 0 is named.
if cuda_ok:
    print(f"CUDA device count: {torch.cuda.device_count()}")
    print(f"CUDA device name: {torch.cuda.get_device_name(0)}")

### Imports

In [None]:
%%time

from helpers import (
    SparseMatrixType as smt,
    SparseMatrixGenerator as smg,
    SparseMatrixVisualizer as smv,
    SparseMatrixBenchmark as smb,
    SparseMatrixAnalysis as sma,
    SparseMatrixPredictor as smp
)

### Generate training and testing sets

In [None]:
%%time

generator = smg()

train_files, val_files = generator.generate_train_val_sets(
    sizes=[512, 1024, 2048],
    types=[smt.RANDOM, smt.BANDED, smt.BLOCK_DIAGONAL, smt.TRIDIAGONAL, smt.CHECKERBOARD],
    num_train=500,
    num_val=100,
    train_prefix="train_matrix",
    val_prefix="test_matrix",
    filename_fmt="{prefix}_size{size}_{type}_{index:08d}.mtx",
    overwrite=True
)

print(f"Generated {len(train_files)} training files")
print(f"Generated {len(val_files)} validation files")

### Benchmark the training set

In [None]:
%%time

# Benchmark the ./train directory.
# NOTE(review): the next cell reads `results_df` without an explicit run call,
# so the constructor presumably performs the benchmarking — confirm in helpers.
benchmark_training = smb("./train")

In [None]:
# Raw training-set benchmark results, displayed as the cell output.
df_training = benchmark_training.results_df
df_training

In [None]:
# Plot the training benchmark results; the argument is presumably the output image path.
benchmark_training.plot_results("results_plot.png")

### Perform ranking (training set)

In [None]:
# Wrap the training benchmark results in the analysis helper.
analyzer_training = sma(df_training)

In [None]:
# Optimal-format table for the training set; it is the input used to build the predictor below.
# NOTE(review): presumably one row per matrix file — confirm in SparseMatrixAnalysis.
optimal_formats_training = analyzer_training.get_optimal_formats()
optimal_formats_training

In [None]:
analyzer_training.print_analysis_summary()

### Develop the decision tree

In [None]:
# Build the format predictor from the training set's optimal formats.
# NOTE(review): `model`, `X_test` and `y_test` are read in the next cell, so the
# constructor presumably splits the data and fits the model — confirm in helpers.
predictor = smp(optimal_formats_training)

In [None]:
# Score the model on the predictor's X_test / y_test split.
accuracy = predictor.model.score(predictor.X_test, predictor.y_test)
print(f"Model accuracy: {accuracy:.2f}")

In [None]:
predictor.visualize_tree("decision_tree.png")

In [None]:
predictor.get_feature_importance()

In [None]:
import pandas as pd

metrics = predictor.evaluate_model()

# Headline scores for the fitted model.
print(f"Accuracy: {metrics['accuracy']:.2f}")
print(f"Macro F1: {metrics['f1_macro']:.2f}")

# Per-class F1, printed only for classes that evaluate_model reported on.
for label in predictor.class_names:
    per_class_key = f'f1_{label}'
    if per_class_key not in metrics:
        continue
    print(f"{label} F1: {metrics[per_class_key]:.2f}")

# The full report is optional; show it as a table when present.
report = metrics['classification_report']
if report is not None:
    print("\nClassification Report:")
    report_table = pd.DataFrame(report).transpose()
    print(report_table)

In [None]:
# NOTE(review): other plot calls in this notebook pass a ".png" filename; confirm
# whether "Confusion_matrix" is meant as an output path (no extension) or a title.
predictor.plot_confusion_matrix("Confusion_matrix")

### Predict optimal formats for the test set

In [None]:
# Predict a format for the matrices in ./test; these predictions are later
# compared with the measured optimal formats from benchmarking the same folder.
df_predictions = predictor.predict_formats_for_folder("./test")
df_predictions

### Benchmark the test set

In [None]:
# Benchmark ./test (unlike the training benchmark, this cell is not timed with %%time).
benchmark_testing = smb("./test")

In [None]:
# Raw test-set benchmark results, displayed as the cell output.
df_testing = benchmark_testing.results_df
df_testing

### Perform ranking (test set)

In [None]:
# Wrap the test benchmark results in the same analysis helper used for training.
analyzer_testing = sma(df_testing)

In [None]:
# Measured optimal formats for the test set (the ground truth for prediction accuracy).
optimal_formats_testing = analyzer_testing.get_optimal_formats()
optimal_formats_testing

### Select columns for comparison

In [None]:
# Keep only the columns compare_format_predictions reads from the predictions.
predictions_only_testing = df_predictions[['filename', 'predicted_format']]
predictions_only_testing

In [None]:
# Keep only the filename and the measured optimal format for each test matrix.
optimal_only_testing = optimal_formats_testing[['filename', 'format']]
optimal_only_testing

### Prediction Accuracy

In [None]:
import pandas as pd

def compare_format_predictions(actual_df, predicted_df, *, return_details=False):
    """Compare predicted storage formats against measured optimal formats.

    Rows are matched on ``filename``. Predicted files with no entry in
    ``actual_df`` are ignored. If ``actual_df`` lists a filename more than
    once, its last row wins.

    Args:
        actual_df: DataFrame with ``filename`` and ``format`` columns.
        predicted_df: DataFrame with ``filename`` and ``predicted_format`` columns.
        return_details: When True, also return a per-file comparison DataFrame
            with columns ``filename``, ``actual_format``, ``predicted_format``
            and ``correct_prediction``.

    Returns:
        The accuracy as a percentage in [0, 100] (0 when no filenames match),
        or ``(accuracy, details)`` when ``return_details`` is True. Returns
        ``None`` after printing an error if a required column is missing.
    """
    for df, name, expected_column in [(actual_df, 'actual', 'format'),
                                      (predicted_df, 'predicted', 'predicted_format')]:
        if 'filename' not in df.columns or expected_column not in df.columns:
            # Schema problems are still reported without raising. Any other
            # unexpected error now propagates instead of being swallowed by a
            # blanket `except Exception`.
            print(f"Error: The {name} dataframe must have 'filename' and '{expected_column}' columns")
            return None

    # dict(zip(...)) keeps the last format seen for a duplicated filename.
    actual_formats = dict(zip(actual_df['filename'], actual_df['format']))

    # Zipping the two columns avoids iterrows' per-row Series construction.
    comparison_results = []
    for filename, predicted_format in zip(predicted_df['filename'],
                                          predicted_df['predicted_format']):
        if filename not in actual_formats:
            continue
        actual_format = actual_formats[filename]
        comparison_results.append({
            'filename': filename,
            'actual_format': actual_format,
            'predicted_format': predicted_format,
            'correct_prediction': bool(predicted_format == actual_format),
        })

    total = len(comparison_results)
    correct = sum(row['correct_prediction'] for row in comparison_results)
    accuracy = (correct / total) * 100 if total > 0 else 0

    print(f"Total files compared: {total}")
    print(f"Correct predictions: {correct}")
    print(f"Incorrect predictions: {total - correct}")
    print(f"Prediction accuracy: {accuracy:.2f}%")
    if total == 0:
        # A 0% here likely reflects mismatched filename spellings between the
        # two tables rather than wrong predictions — worth checking.
        print("Warning: no filenames matched between the actual and predicted dataframes")

    if return_details:
        details = pd.DataFrame(
            comparison_results,
            columns=['filename', 'actual_format', 'predicted_format', 'correct_prediction'],
        )
        return accuracy, details
    return accuracy

In [None]:
# Percentage of matched test files whose predicted format equals the measured optimum
# (None if either dataframe fails the column check).
# NOTE: this rebinds `accuracy`, overwriting the model score assigned in the earlier cell.
accuracy = compare_format_predictions(optimal_only_testing, predictions_only_testing)