# Encrypted BSAVNN: Lung Cancer Detection using Basis Scaling and Activations Vectorized Neural Networks (Direct Execution - Cleaned)

This Jupyter Notebook integrates your original BSAVNN implementation logic and extends it with performance evaluation and graph generation for a lung cancer detection task. **The `if __name__ == "__main__":` guard has been removed** so that the main logic runs directly in the notebook environment; this helps when no output appears. **All unnecessary semicolons have also been removed.**

**Note:** For this notebook to run as intended, you should have:
- A pre-trained PyTorch model saved as `lung_cancer_model.pth`.
- Test data in `test.csv` and corresponding labels in `y_test.csv`.
- Required Python libraries (`pandas`, `torch`, `numpy`, `matplotlib`, `scikit-learn`, `psutil`, and `streamlit`, which is imported but not used by the core logic).

If `test.csv` or `y_test.csv` is not found, the notebook generates random dummy data. If `lung_cancer_model.pth` is not found, or its input dimension does not match the data, the model is initialized with random weights. Accuracy metrics are not meaningful in these cases, but the performance overhead analysis remains valid.


In [None]:
import pandas as pd
import streamlit as st # Not directly used in the core logic, but kept as per your original script
import torch
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import time
from sklearn.metrics import accuracy_score, classification_report
from sklearn.preprocessing import LabelEncoder, StandardScaler # Not directly used in this script but kept as per original
import sys
import os
import psutil

# --- Add this line here to ensure plots appear inline in Jupyter/VS Code ---
%matplotlib inline 
# -------------------------------------------------------------------------

# --- Utility Functions for Performance Measurement ---
class MemoryMonitor:
    """Context manager that snapshots this process's resident set size (RSS).

    On entry the RSS is stored in ``initial_mem`` and on exit in ``final_mem``.
    Both values are bytes, as reported by ``psutil``. Exceptions raised inside
    the ``with`` body are not suppressed.
    """

    def _current_rss(self):
        # Resident set size of the tracked process, in bytes.
        return self.process.memory_info().rss

    def __enter__(self):
        self.process = psutil.Process(os.getpid())
        self.initial_mem = self._current_rss()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.final_mem = self._current_rss()

def do_normal(model, test_data, y_test_labels):
    """Run plaintext inference one row at a time and score the predictions.

    Each DataFrame row is fed to ``model`` as a (1, n_features) float32 tensor.
    The output is thresholded at 0.5 (presumably a sigmoid probability; confirm
    for non-DeepNN models). The measured time covers the whole loop, including
    pandas ``iterrows`` and tensor construction, but not the accuracy scoring.

    Returns:
        (accuracy, total elapsed milliseconds, list of 0/1 predictions). The
        prediction list is returned so callers can build a classification report.
    """
    predictions = []
    tic = time.perf_counter()
    with torch.no_grad():
        for _, features in test_data.iterrows():
            batch = torch.tensor(features.values, dtype=torch.float32).unsqueeze(0)
            prob = model(batch).squeeze(0)
            predictions.append(1 if prob.item() > 0.5 else 0)
    elapsed_ms = (time.perf_counter() - tic) * 1000
    return accuracy_score(y_test_labels, predictions), elapsed_ms, predictions

# --- YOUR ORIGINAL CODE STARTS HERE (modified ONLY for correctness and scope) ---


In [None]:
class DeepNN(nn.Module):
    """Fully connected binary classifier with a single sigmoid output.

    The hidden widths are 64 -> 48 -> 32 -> 16 -> 8, each followed by ReLU,
    then Linear(8, 1) and Sigmoid. All layers live in ``self.model``, an
    ``nn.Sequential``, so the Linear layers sit at indices 0, 2, 4, 6, 8 and 10.
    Checkpoint keys such as ``'model.0.weight'`` and the BSAVNN layer-index
    logic rely on this exact layout.
    """

    def __init__(self, input_dim):
        super().__init__()
        hidden_widths = (64, 48, 32, 16, 8)
        layers = []
        fan_in = input_dim
        for width in hidden_widths:
            layers += [nn.Linear(fan_in, width), nn.ReLU()]
            fan_in = width
        layers += [nn.Linear(fan_in, 1), nn.Sigmoid()]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

def _reinitialize_weights(module):
    """Re-initialize every parameter of ``module`` in place.

    Weight matrices get Kaiming-uniform initialization (ReLU gain) and 1-D
    parameters (biases) are set to zero.
    """
    for param in module.parameters():
        if param.dim() > 1:
            nn.init.kaiming_uniform_(param, nonlinearity='relu')
        else:
            nn.init.constant_(param, 0.0)


# Modified: `input_dim_for_model` is now an explicit argument
def load_model(input_dim_for_model, checkpoint_path="lung_cancer_model.pth"):
    """Build a DeepNN for ``input_dim_for_model`` features and return its layers.

    Pre-trained weights are loaded from ``checkpoint_path`` when the file
    exists and its first-layer weight (``'model.0.weight'``) has exactly
    ``input_dim_for_model`` input columns. In every other case (file missing,
    dimension mismatch, or any other load error) a message is printed and the
    model falls back to random weights.

    Args:
        input_dim_for_model: Number of input features.
        checkpoint_path: Path of the saved ``state_dict``. It defaults to the
            notebook's original file name.

    Returns:
        The model's inner ``nn.Sequential`` in eval mode (``DeepNN.model``),
        which the BSAVNN layer functions index into directly.
    """
    saved_model = DeepNN(input_dim_for_model)
    try:
        # map_location="cpu": a checkpoint saved from a GPU would otherwise fail
        # to load on a CPU-only machine and silently fall back to random weights.
        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        first_weight = checkpoint.get('model.0.weight') if isinstance(checkpoint, dict) else None
        # Inspect the first layer's shape before loading so that a checkpoint
        # trained for a different input width is rejected cleanly.
        if first_weight is not None and first_weight.shape[1] == input_dim_for_model:
            saved_model.load_state_dict(checkpoint)
            print(f"Model for input_dim={input_dim_for_model} loaded successfully from {checkpoint_path}.")
        else:
            found_dim = first_weight.shape[1] if first_weight is not None else 'N/A'
            print(f"Warning: Pre-trained model in .pth has input_dim {found_dim}, but current is {input_dim_for_model}. Initializing with random weights.")
            _reinitialize_weights(saved_model)
    except FileNotFoundError:
        print(f"{checkpoint_path} not found for input_dim={input_dim_for_model}. Initializing with random weights.")
        _reinitialize_weights(saved_model)
    except Exception as e:
        # Deliberate best-effort boundary. For example, load_state_dict may have
        # copied some tensors before raising, so every parameter is reset.
        print(f"An unexpected error occurred loading model: {e}. Initializing with random weights for input_dim={input_dim_for_model}.")
        _reinitialize_weights(saved_model)

    saved_model.eval()
    return saved_model.model


In [None]:
# The original script printed saved_model_wb and its layer shapes at global scope.
# Those diagnostic prints are not reproduced in the main execution flow of this notebook.


### BSAVNN Core Logic
The functions below implement the basis scaling, activation vectorization, and secure ReLU logic. Their signatures have been minimally adjusted to pass necessary variables (like `an_plus_1_val` and `eval_vec`) explicitly, ensuring correct scope and execution within the comprehensive testing framework.


In [None]:
def initialize_dynamic_keys(input_dim_val):
    """Draw a fresh set of BSAVNN scaling keys for ``input_dim_val`` features.

    Returns a 5-tuple ``(enc_vec, an_plus_1, dec_vec, eval_vec, k)``:
      * ``enc_vec``: float32 tensor (n,) of per-feature scales in [0.9, 1.1).
      * ``an_plus_1``: float32 scalar tensor in [1.0, 2.0) that scales biases.
      * ``dec_vec``: float32 tensor (n + 1,) holding ``1 / enc_vec`` followed by
        ``1 / an_plus_1``.
      * ``eval_vec``: ``dec_vec / k`` reshaped to a column of shape (n + 1, 1).
      * ``k``: Python float, fixed at 1.0.

    Keys come from the global NumPy RNG, so seed ``np.random`` for reproducibility.
    """
    feature_keys = np.random.uniform(0.9, 1.1, input_dim_val)
    bias_key = np.random.uniform(1.0, 2.0)

    enc_vec_dyn = torch.tensor(feature_keys, dtype=torch.float32)
    an_plus_1_dyn = torch.tensor(bias_key, dtype=torch.float32)
    k_dyn = 1.0

    bias_inverse = torch.tensor([1 / an_plus_1_dyn.item()], dtype=torch.float32)
    dec_vec_dyn = torch.cat((1 / enc_vec_dyn, bias_inverse))
    eval_vec_dyn = (dec_vec_dyn / k_dyn).view(-1, 1)
    return enc_vec_dyn, an_plus_1_dyn, dec_vec_dyn, eval_vec_dyn, k_dyn

def encrypt_data(input_values, enc_vec_arg):
    """Blind a plaintext feature vector by scaling each feature with its key.

    ``input_values`` is converted to a float32 tensor and multiplied
    element-wise by ``enc_vec_arg``; the product is returned.
    """
    return enc_vec_arg * torch.tensor(input_values, dtype=torch.float32)

def decrypt_ans(activation_matrix_final_layer, dec_vec_arg, k_val):
    """Decrypt one encrypted output row into a probability.

    The row is dotted with the decryption key column and scaled by ``k_val`` to
    recover the logit, which is passed through a sigmoid. The result is a 0-d
    float32 tensor.
    """
    key_column = dec_vec_arg.view(-1, 1).to(torch.float32)
    row = activation_matrix_final_layer.unsqueeze(0)
    logit = k_val * (row @ key_column).item()
    return torch.sigmoid(torch.tensor(logit, dtype=torch.float32))

def input_layer_calc(initial_activation_matrix_list_dummy, saved_model_wb_arg, enc_inp, eval_vec_arg, an_plus_1_val):
    """Build the first layer's encrypted activation matrix (BSAVNN).

    For output neuron j the row is ``[enc_inp * W[j, :], an_plus_1_val * b[j]]``,
    which has length n + 1. With keys from ``initialize_dynamic_keys`` (k = 1),
    the row's dot product with ``eval_vec_arg`` equals the neuron's plaintext
    pre-activation. Rows whose dot product is not strictly positive are replaced
    by zeros; this is the secure ReLU.

    Args:
        initial_activation_matrix_list_dummy: Unused; kept for signature compatibility.
        saved_model_wb_arg: Indexable model whose element 0 is the first nn.Linear.
        enc_inp: Encrypted input vector, shape (n,).
        eval_vec_arg: Evaluation key column, shape (n + 1, 1).
        an_plus_1_val: Bias scaling key.

    Returns:
        Tensor of shape (out_features, n + 1).
    """
    first_layer = saved_model_wb_arg[0]
    weights_t = first_layer.weight.data.T.to(torch.float32)  # (n, out_features)
    biases = first_layer.bias.data.to(torch.float32)

    rows = torch.stack([
        torch.cat((enc_inp * weights_t[:, j], (an_plus_1_val * biases[j]).unsqueeze(0)))
        for j in range(weights_t.shape[1])
    ])

    masked_rows = []
    for row in rows:
        pre_activation = torch.matmul(row.unsqueeze(0), eval_vec_arg).item()
        # relu(x) > 0 exactly when x > 0 (NaN compares False), so keep only positive rows.
        masked_rows.append(row if pre_activation > 0 else torch.zeros_like(row))
    return torch.stack(masked_rows)

def _is_last_linear_layer(layers, layer_index):
    """Return True when no nn.Linear module follows position ``layer_index`` in ``layers``."""
    return not any(isinstance(module, nn.Linear) for module in list(layers)[layer_index + 1:])


def intermediate_layer_calc(activation_matrix, saved_model_wb_arg, eval_vec_arg, an_plus_1_val, layer_index):
    """Propagate an encrypted activation matrix through one hidden/output Linear layer.

    For output neuron j, the new row is the weighted sum of the incoming rows,
    ``sum_i W[j, i] * activation_matrix[i]``, with ``an_plus_1_val * b[j]`` added
    to the last (bias-key) slot. Secure ReLU then zeroes every row whose dot
    product with ``eval_vec_arg`` is not strictly positive.

    The ReLU is skipped for the model's last nn.Linear layer, which yields the
    logit. Previously this was hard-coded as ``layer_index == 10`` (DeepNN's
    layout); it is now detected from ``saved_model_wb_arg`` itself, which is
    identical for DeepNN and also correct for other depths.

    Args:
        activation_matrix: Tensor (in_features, n + 1) from the previous layer.
        saved_model_wb_arg: Indexable/iterable model (e.g. nn.Sequential).
        eval_vec_arg: Evaluation key column, shape (n + 1, 1).
        an_plus_1_val: Bias scaling key.
        layer_index: Position of the nn.Linear to apply inside ``saved_model_wb_arg``.

    Returns:
        Tensor of shape (out_features, n + 1).
    """
    layer = saved_model_wb_arg[layer_index]
    weights_t = layer.weight.data.T.to(torch.float32)  # (in_features, out_features)
    biases = layer.bias.data.to(torch.float32)

    new_rows = []
    for j in range(weights_t.shape[1]):
        # Weighted combination of the incoming encrypted rows.
        row = (activation_matrix * weights_t[:, j].view(-1, 1)).sum(dim=0)
        # The bias lives in the bias-key slot, scaled by a_{n+1}.
        row[-1] += (an_plus_1_val * biases[j]).item()
        new_rows.append(row)
    new_matrix = torch.stack(new_rows)

    if _is_last_linear_layer(saved_model_wb_arg, layer_index):
        return new_matrix

    masked_rows = []
    for row in new_matrix:
        pre_activation = torch.matmul(row.unsqueeze(0), eval_vec_arg).item()
        # relu(x) > 0 exactly when x > 0, so only positive rows survive.
        masked_rows.append(row if pre_activation > 0 else torch.zeros_like(row))
    return torch.stack(masked_rows)

def bsavnn_model_pred(input_values, saved_model_wb_arg, enc_vec_arg, eval_vec_arg, an_plus_1_val):
    """Perform a full BSAVNN forward pass for one sample.

    The raw ``input_values`` are encrypted with ``enc_vec_arg``, passed through
    the first Linear layer (index 0), and then through every subsequent
    nn.Linear in ``saved_model_wb_arg``. Previously the indices were hard-coded
    as ``range(2, 11, 2)``, which is DeepNN's layout; they are now read from
    the model.

    The secure-ReLU step is applied between Linear layers whatever
    non-Linear modules the plaintext model contains, so this assumes (as in
    DeepNN) that the hidden activations are ReLU.

    Returns:
        The encrypted activation matrix of the last Linear layer, shape
        (out_features, n + 1); for DeepNN that is (1, n + 1).

    Raises:
        ValueError: If the model's first module is not an nn.Linear.
    """
    linear_indices = [
        idx for idx, module in enumerate(saved_model_wb_arg) if isinstance(module, nn.Linear)
    ]
    if not linear_indices or linear_indices[0] != 0:
        raise ValueError("BSAVNN expects the model's first module to be an nn.Linear layer.")

    enc_inp = encrypt_data(input_values, enc_vec_arg)
    current_activation_matrix = input_layer_calc([], saved_model_wb_arg, enc_inp, eval_vec_arg, an_plus_1_val)

    for layer_idx in linear_indices[1:]:
        current_activation_matrix = intermediate_layer_calc(
            current_activation_matrix, saved_model_wb_arg, eval_vec_arg, an_plus_1_val, layer_idx
        )

    return current_activation_matrix


## Main Execution Block and Performance Evaluation
This section orchestrates the loading of data and models, runs the BSAVNN and plaintext inference, and generates the required performance graphs.


In [None]:
# Helper function for running evaluation for varying input dimensions
def run_evaluation_for_input_dim(current_input_dim, num_samples=100):
    """Benchmark plaintext vs. BSAVNN inference on random data of one input width.

    The function generates ``num_samples`` random rows of ``current_input_dim``
    features with random 0/1 labels. It draws fresh keys and loads (or randomly
    initializes) the model for that width. It then times:
      * plaintext inference via ``do_normal`` on a freshly constructed DeepNN;
      * BSAVNN client encryption, server inference and client decryption, per sample.

    Note that ``bsavnn_model_pred`` encrypts the raw row itself, so the
    "server" time also contains one encryption. The separately timed
    encryption stands for the client-side cost.

    Args:
        current_input_dim: Number of input features n.
        num_samples: Rows to evaluate. It must be positive because every timing
            is averaged over it.

    Returns:
        A dict of per-sample average timings (ms) and theoretical activation memory (MB).

    Raises:
        ValueError: If ``num_samples`` is not positive.
    """
    if num_samples <= 0:
        raise ValueError(f"num_samples must be positive, got {num_samples}")

    dummy_test_data = pd.DataFrame(np.random.rand(num_samples, current_input_dim))
    dummy_y_test = [np.random.randint(0, 2) for _ in range(num_samples)]

    # Initialize keys and model for THIS specific input_dim
    enc_vec_local, an_plus_1_local, dec_vec_local, eval_vec_local, k_local = initialize_dynamic_keys(current_input_dim)
    saved_model_wb_local = load_model(current_input_dim)

    # Plaintext Model Evaluation (accuracy and predictions are not needed here)
    _, normal_inference_time_ms, _ = do_normal(DeepNN(current_input_dim), dummy_test_data, dummy_y_test)

    total_encrypt_time_ms = 0.0
    total_decrypt_time_ms = 0.0
    total_bsavnn_server_inference_time_ms = 0.0

    # Theoretical activation memory (float32 = 4 bytes). The widest layer (64
    # neurons for DeepNN) dominates: BSAVNN stores an (n + 1)-element row per
    # neuron, whereas the plaintext network stores one scalar per neuron.
    widest_layer = max(m.out_features for m in saved_model_wb_local if isinstance(m, nn.Linear))
    bsavnn_activation_mem_bytes = widest_layer * (current_input_dim + 1) * 4
    plaintext_activation_mem_bytes = widest_layer * 4

    for idx in range(num_samples):
        input_row = dummy_test_data.iloc[idx].values

        start_encrypt = time.perf_counter()
        encrypt_data(input_row, enc_vec_local)  # timed for the client-side cost only
        total_encrypt_time_ms += (time.perf_counter() - start_encrypt) * 1000

        start_inference = time.perf_counter()
        final_encrypted_logit_vector = bsavnn_model_pred(input_row, saved_model_wb_local, enc_vec_local, eval_vec_local, an_plus_1_local)
        total_bsavnn_server_inference_time_ms += (time.perf_counter() - start_inference) * 1000

        start_decrypt = time.perf_counter()
        # The final layer has a single neuron, so row 0 is the encrypted logit.
        decrypt_ans(final_encrypted_logit_vector[0], dec_vec_local, k_local)
        total_decrypt_time_ms += (time.perf_counter() - start_decrypt) * 1000

    # Bug fix: the total previously added the *global* accuracy-run totals
    # (total_bsavnn_encrypt_time / total_bsavnn_decrypt_time) instead of these locals.
    total_end_to_end_ms = total_encrypt_time_ms + total_bsavnn_server_inference_time_ms + total_decrypt_time_ms

    return {
        'input_dim': current_input_dim,
        'normal_inference_time_ms': normal_inference_time_ms / num_samples,
        'bsavnn_total_inference_time_ms': total_end_to_end_ms / num_samples,
        'bsavnn_encrypt_time_ms': total_encrypt_time_ms / num_samples,
        'bsavnn_decrypt_time_ms': total_decrypt_time_ms / num_samples,
        'bsavnn_server_inference_time_ms': total_bsavnn_server_inference_time_ms / num_samples,
        'bsavnn_activation_mem_mb': bsavnn_activation_mem_bytes / (1024**2),
        'plaintext_activation_mem_mb': plaintext_activation_mem_bytes / (1024**2),
    }

def _finish_performance_figure(ylabel, title, save_path, log_y=True):
    """Apply the shared labels, log scales and grid to the current figure, then save and show it."""
    plt.xlabel('Input Dimension ($n$)')
    plt.ylabel(ylabel)
    plt.title(title)
    plt.legend()
    plt.xscale('log')
    if log_y:
        plt.yscale('log')
    plt.grid(True, which="both", ls="--", c='0.7')
    plt.tight_layout()
    plt.savefig(save_path)
    plt.show()


# Helper to plot graphs (Defined here, outside the main execution block)
def plot_performance_graphs(results_df_plot):
    """Render, save and show the three BSAVNN-vs-plaintext performance figures.

    ``results_df_plot`` must contain the columns produced by
    ``run_evaluation_for_input_dim``: 'input_dim', 'normal_inference_time_ms',
    'bsavnn_total_inference_time_ms', 'bsavnn_server_inference_time_ms',
    'plaintext_activation_mem_mb' and 'bsavnn_activation_mem_mb'. Figures are
    written to ./figures/, which is created if missing.
    """
    try:
        plt.style.use('seaborn-v0_8-darkgrid')
    except OSError:
        # Matplotlib < 3.6 ships this style as 'seaborn-darkgrid' (no 'v0_8' prefix).
        plt.style.use('seaborn-darkgrid')
    fig_size = (10, 6)

    os.makedirs('figures', exist_ok=True)

    input_dims = results_df_plot['input_dim']
    plain_ms = results_df_plot['normal_inference_time_ms']
    total_ms = results_df_plot['bsavnn_total_inference_time_ms']

    # Time vs. Input Size
    plt.figure(figsize=fig_size)
    plt.plot(input_dims, plain_ms, label='Plaintext NN', marker='o', color='blue')
    plt.plot(input_dims, total_ms, label='BSAVNN (Total End-to-End)', marker='x', color='red')
    plt.plot(input_dims, results_df_plot['bsavnn_server_inference_time_ms'], label='BSAVNN (Server Only)', marker='s', color='green', linestyle='--')
    _finish_performance_figure('Inference Time (ms)', 'Computation Time vs. Input Dimension', './figures/time_vs_input_size.png')

    # Memory Usage vs. Input Size
    plt.figure(figsize=fig_size)
    plt.plot(input_dims, results_df_plot['plaintext_activation_mem_mb'], label='Plaintext NN (Activations)', marker='o', color='blue')
    plt.plot(input_dims, results_df_plot['bsavnn_activation_mem_mb'], label='BSAVNN (Activations)', marker='x', color='red')
    _finish_performance_figure('Memory Usage (MB)', 'Activation Memory Usage vs. Input Dimension', './figures/memory_vs_input_size.png')

    # Relative Overhead vs. Input Size (linear y axis: percentages can be near zero or negative)
    plt.figure(figsize=fig_size)
    overhead_percentage = ((total_ms - plain_ms) / plain_ms) * 100
    plt.plot(input_dims, overhead_percentage, label='BSAVNN Relative Time Overhead', marker='o', color='purple')
    _finish_performance_figure('Relative Overhead (%)', 'BSAVNN Relative Time Overhead vs. Input Dimension', './figures/relative_overhead_vs_input_size.png', log_y=False)

# --- Main Execution Logic (now directly executed without if __name__ guard) ---

# --- Configuration ---
actual_input_dim = 15  # Feature count used for dummy data when the test CSVs are missing
paper_input_dim = 16384  # As specified in your paper for the 128x128 image (e.g., for plotting)
num_samples_for_graphs = 50  # Number of samples to use for each input_dim when generating graphs

# --- Data Loading and Initialization ---
test_data_path = "test.csv"
y_test_path = "y_test.csv"
try:
    test_data = pd.read_csv(test_data_path)
    # NOTE(review): read_csv treats the first line of y_test.csv as a header row; confirm the file has one.
    y_test = pd.read_csv(y_test_path).values.flatten()
except FileNotFoundError:
    print("Test data CSVs not found. Generating dummy data for accuracy test.")
    test_data = pd.DataFrame(np.random.rand(100, actual_input_dim))
    y_test = np.random.randint(0, 2, 100)
else:
    # Fail fast with a clear message instead of deep inside accuracy_score later.
    if len(y_test) != len(test_data):
        raise ValueError(
            f"{y_test_path} has {len(y_test)} labels but {test_data_path} has {len(test_data)} rows."
        )
    actual_input_dim = test_data.shape[1]
    print(f"Loaded test data with {actual_input_dim} features.")

# Session keys for the accuracy run. These are ordinary module-level names;
# the former `global` statements were no-ops at module scope.
enc_vec, an_plus_1, dec_vec, eval_vec, k = initialize_dynamic_keys(actual_input_dim)

# Sequential layers of the loaded (or randomly initialized) model, used by BSAVNN inference.
saved_model_wb = load_model(actual_input_dim)

# --- BSAVNN Single Inference Demonstration ---
# Encrypt, evaluate and decrypt the first test row, timing each stage on its own.
print("\n--- BSAVNN Single Inference Demonstration ---")
if test_data.empty:
    print("No test data loaded for single inference demonstration.")
else:
    sample_input_values = test_data.iloc[0].values

    # Client: blind the features (this copy is only displayed below).
    start_encryption_time = time.perf_counter()
    encrypted_sample_input = encrypt_data(sample_input_values, enc_vec)
    end_encryption_time = time.perf_counter()

    # Server: bsavnn_model_pred encrypts the raw row itself, then runs every layer.
    start_bsavnn_inference_time = time.perf_counter()
    final_encrypted_logit_vector_sample = bsavnn_model_pred(sample_input_values, saved_model_wb, enc_vec, eval_vec, an_plus_1)
    end_bsavnn_inference_time = time.perf_counter()

    # Client: decrypt the single output row into a probability.
    start_decryption_time = time.perf_counter()
    final_predicted_prob_sample = decrypt_ans(final_encrypted_logit_vector_sample[0], dec_vec, k)
    end_decryption_time = time.perf_counter()

    stage_ms = {
        'encryption': (end_encryption_time - start_encryption_time) * 1000,
        'server': (end_bsavnn_inference_time - start_bsavnn_inference_time) * 1000,
        'decryption': (end_decryption_time - start_decryption_time) * 1000,
    }
    preview_width = min(5, final_encrypted_logit_vector_sample.shape[1])

    print(f"Sample Input (first 5 features): {sample_input_values[:5]}")
    print(f"Encrypted Input (first 5 features): {encrypted_sample_input[:5].tolist()}")
    print(f"Final Encrypted Logit Vector (truncated): {final_encrypted_logit_vector_sample[0, :preview_width].tolist()}")
    print(f"Decrypted Final Probability for sample: {final_predicted_prob_sample.item():.4f}")
    print(f"Encryption Time: {stage_ms['encryption']:.3f} ms")
    print(f"Server Inference Time: {stage_ms['server']:.3f} ms")
    print(f"Decryption Time: {stage_ms['decryption']:.3f} ms")

# --- Run full BSAVNN inference for accuracy and full timing ---
print("\n--- Running Full BSAVNN Inference for Metrics ---")
y_pred_bsavnn = []
# Accumulated milliseconds per stage over the whole test set.
total_bsavnn_encrypt_time = 0
total_bsavnn_server_time = 0
total_bsavnn_decrypt_time = 0

for index, row in test_data.iterrows():
    # Client-side encryption cost. bsavnn_model_pred re-encrypts the raw row
    # internally, so the server time below also contains one encryption.
    start_time = time.perf_counter()
    encrypted_input_row = encrypt_data(row.values, enc_vec)
    end_time = time.perf_counter()
    total_bsavnn_encrypt_time += (end_time - start_time) * 1000

    start_time = time.perf_counter()
    final_encrypted_logit_vector_row = bsavnn_model_pred(row.values, saved_model_wb, enc_vec, eval_vec, an_plus_1)
    end_time = time.perf_counter()
    total_bsavnn_server_time += (end_time - start_time) * 1000

    start_time = time.perf_counter()
    final_prob_row = decrypt_ans(final_encrypted_logit_vector_row[0], dec_vec, k)
    end_time = time.perf_counter()
    total_bsavnn_decrypt_time += (end_time - start_time) * 1000

    predicted_class = int(final_prob_row.item() > 0.5)
    y_pred_bsavnn.append(predicted_class)

num_test_samples = len(test_data)
bsavnn_acc = accuracy_score(y_test, y_pred_bsavnn)
# Emoji previously appeared as mojibake ("âœ…", "ðŸ“Š"): UTF-8 bytes decoded as cp1252.
print("\n✅ BSAVNN Test Accuracy:", round(bsavnn_acc * 100, 2), "%")
print("\n📊 BSAVNN Classification Report:\n", classification_report(y_test, y_pred_bsavnn, zero_division=0))
print(f"BSAVNN Average Encryption Time: {total_bsavnn_encrypt_time / num_test_samples:.3f} ms/sample")
print(f"BSAVNN Average Server Inference Time: {total_bsavnn_server_time / num_test_samples:.3f} ms/sample")
print(f"BSAVNN Average Decryption Time: {total_bsavnn_decrypt_time / num_test_samples:.3f} ms/sample")
print(f"BSAVNN Average Total End-to-End Time: {(total_bsavnn_encrypt_time + total_bsavnn_server_time + total_bsavnn_decrypt_time) / num_test_samples:.3f} ms/sample")

# --- Run Plaintext inference for comparison ---
print("\n--- Running Plaintext Inference for Comparison ---")
# Build the plaintext reference from the very same weights the BSAVNN pass used
# (saved_model_wb, returned by load_model). Re-loading the checkpoint here
# caused two problems:
#   (a) whenever load_model fell back to random initialization, the two models
#       got *different* random weights, making the accuracy comparison meaningless;
#   (b) a checkpoint whose input_dim differs from the data raised an uncaught
#       RuntimeError from load_state_dict, because only FileNotFoundError was handled.
original_plaintext_model = DeepNN(actual_input_dim)
original_plaintext_model.model.load_state_dict(saved_model_wb.state_dict())
original_plaintext_model.eval()

normal_acc, normal_time_ms_total, y_pred_normal = do_normal(original_plaintext_model, test_data, y_test)
print("Normal Model Accuracy:", round(normal_acc * 100, 2), "%")
# Emoji previously appeared as mojibake ("ðŸ“Š"): UTF-8 bytes decoded as cp1252.
print("\n📊 Normal Classification Report:\n", classification_report(y_test, y_pred_normal, zero_division=0))
print(f"Normal Model Average Inference Time: {normal_time_ms_total / len(test_data):.3f} ms/sample")


# --- Generate Data for Graphs (varying input dimensions) ---
print("\n--- Generating Data for Performance Graphs (this may take a while for large dims) ---")
# Always include the paper's input dimension. Previously it was added only when
# it exceeded every listed size; the set removes duplicates and sorted() keeps
# the x-axis ascending.
input_dims_to_test = sorted({4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, paper_input_dim})

performance_results = []
for dim in input_dims_to_test:
    print(f"Evaluating for input_dim={dim} with {num_samples_for_graphs} samples...")
    result = run_evaluation_for_input_dim(dim, num_samples=num_samples_for_graphs)
    performance_results.append(result)

results_df = pd.DataFrame(performance_results)
print("\nPerformance Results Data:")
print(results_df.to_string())

# --- Plot Graphs ---
print("\n--- Plotting Performance Graphs ---")

os.makedirs('figures', exist_ok=True)

# NOTE: A verbatim second definition of plot_performance_graphs used to live here.
# It silently shadowed the earlier definition in this cell, so edits to that one
# had no effect. The earlier definition is now the only one, and it creates the
# ./figures directory itself.
plot_performance_graphs(results_df)  # Call the plotting function with the results DataFrame
print("Graphs saved as time_vs_input_size.png, memory_vs_input_size.png, relative_overhead_vs_input_size.png in the 'figures' directory.")
