# HDNN-ArtifactBrainState
HDNN-ArtifactBrainState is a project that integrates Hopfield networks with deep neural networks to improve brain-state decoding, with a focus on robustness to artifacts. This notebook and its associated utilities contain the code and resources for implementing the HDNN framework described in our recent publication.

<img src="data/HDNN-pipeline.png" width="1000px">


## Importing the code and defining the parameters 

In [None]:
import utils as fx
from numpy import mean
import pickle

# --- Experiment configuration --------------------------------------------
subj = 'subject1'            # dataset name (the demo only ships subject1)
num_samples_per_state = 100  # samples used per state; the demo has up to 288
segment_length = 250         # window length in milliseconds for the demo data
n_epochs = 10                # epoch count handed to every pipeline call

# Grid swept by each of the experiment cells below.
imgs = [10, 20]                # image sizes
err_percentages = [0.05, 0.1]  # fraction of all pixels corrupted in each image

## Case 1: No Artifacts are added to the data.
The following code iterates over the image sizes in `imgs` and the error percentages in `err_percentages`. Because this is the "No Artifacts" case, no artifacts are added to the data.

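A file `res_noartifact.pkl` is created to store the test accuracy and the normalized confusion matrix (each averaged over the runs) for every combination. You can access the results either by loading the file or through the variable `res_noartifact`. Combinations already present in the file are skipped when the cell is re-run.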

In [None]:
import os
import pickle

# Results cache for the artifact-free case, keyed by (err, img_size). Each value
# holds the test accuracy and normalized confusion matrix averaged over n runs.
ckpt_path = 'res_noartifact.pkl'

# Load res_noartifact from disk if it exists, so finished combinations are reused.
try:
    with open(ckpt_path, 'rb') as f:
        res_noartifact = pickle.load(f)
except FileNotFoundError:
    res_noartifact = {}  # Initialize an empty dictionary if file not found

n = 2  # number of repeated runs averaged per (err, img_size) combination

for err in err_percentages:
    for img_size in imgs:
        # Skip computation if this combination already exists
        if (err, img_size) in res_noartifact:
            print('skipped!', err, img_size)
            continue

        temp_test_acc = []
        temp_cm_normalized = []

        for _ in range(n):
            test_acc, cm_normalized = fx.pipeline_noartifact(subj, num_samples_per_state, img_size, segment_length, err, n_epochs)
            temp_test_acc.append(test_acc)
            temp_cm_normalized.append(cm_normalized)

        # Average the accuracy, and the confusion matrix element-wise, over the n runs
        avg_test_acc = mean(temp_test_acc)
        avg_cm_normalized = mean(temp_cm_normalized, axis=0)

        # Save the averaged results in the dictionary
        res_noartifact[(err, img_size)] = {'test_acc': avg_test_acc, 'cm_normalized': avg_cm_normalized}

        # Checkpoint after every combination. Dump to a temp file and swap it in
        # atomically: an interrupt mid-dump then cannot leave a truncated .pkl that
        # would make the load above fail and lose every saved result.
        tmp_path = ckpt_path + '.tmp'
        with open(tmp_path, 'wb') as f:
            pickle.dump(res_noartifact, f)
        os.replace(tmp_path, ckpt_path)


## Case 2: Hopfield + CNN model
The following code iterates over the image sizes in `imgs` and the error percentages in `err_percentages`. Artifacts are added to the data according to the chosen `err_percentages` and are then handled by the Hopfield model.

A file `results_hop.pkl` is created to store the test accuracy and the normalized confusion matrix (each averaged over the runs) for every combination. You can access the results either by loading the file or through the variable `results_2hop`.

In [None]:
import os
import pickle

# Results cache for the Hopfield + CNN model, keyed by (err, img_size). Each value
# holds the test accuracy and normalized confusion matrix averaged over n runs.
ckpt_path = 'results_hop.pkl'

# Load results_2hop from disk if it exists, so finished combinations are reused.
try:
    with open(ckpt_path, 'rb') as f:
        results_2hop = pickle.load(f)
except FileNotFoundError:
    results_2hop = {}  # Initialize an empty dictionary if file not found

n = 2  # number of repeated runs averaged per combination (same as the other cases)

for err in err_percentages:
    for img_size in imgs:

        # Skip computation if this combination already exists
        if (err, img_size) in results_2hop:
            print('skipped!', err, img_size)
            continue

        temp_test_acc = []
        temp_cm_normalized = []

        for _ in range(n):
            test_acc, cm_normalized = fx.pipeline_hopfield_rec(subj, num_samples_per_state, img_size, segment_length, err, n_epochs)
            temp_test_acc.append(test_acc)
            temp_cm_normalized.append(cm_normalized)

        # Average the accuracy, and the confusion matrix element-wise, over the n runs
        avg_test_acc = mean(temp_test_acc)
        avg_cm_normalized = mean(temp_cm_normalized, axis=0)

        # Save the averaged results in the dictionary
        results_2hop[(err, img_size)] = {'test_acc': avg_test_acc, 'cm_normalized': avg_cm_normalized}

        # Checkpoint after every combination. Dump to a temp file and swap it in
        # atomically: an interrupt mid-dump then cannot leave a truncated .pkl that
        # would make the load above fail and lose every saved result.
        tmp_path = ckpt_path + '.tmp'
        with open(tmp_path, 'wb') as f:
            pickle.dump(results_2hop, f)
        os.replace(tmp_path, ckpt_path)

## Case 3: Baseline
The following code iterates over the image sizes in `imgs` and the error percentages in `err_percentages`. Artifacts are added to the data according to the chosen `err_percentages`, but they are not treated by the Hopfield model.

A file `results.pkl` is created to store the test accuracy and the normalized confusion matrix (each averaged over the runs) for every combination. You can access the results either by loading the file or through the variable `results`.

In [None]:
import os
import pickle

# Results cache for the baseline model, keyed by (err, img_size). Each value
# holds the test accuracy and normalized confusion matrix averaged over n runs.
ckpt_path = 'results.pkl'

# Load results from disk if it exists, so finished combinations are reused.
try:
    with open(ckpt_path, 'rb') as f:
        results = pickle.load(f)
except FileNotFoundError:
    results = {}  # Initialize an empty dictionary if file not found

n = 2  # number of repeated runs averaged per (err, img_size) combination

for err in err_percentages:
    for img_size in imgs:

        # Skip computation if this combination already exists
        if (err, img_size) in results:
            print('skipped!', err, img_size)
            continue

        temp_test_acc = []
        temp_cm_normalized = []

        for _ in range(n):
            test_acc, cm_normalized = fx.pipeline(subj, num_samples_per_state, img_size, segment_length, err, n_epochs)
            temp_test_acc.append(test_acc)
            temp_cm_normalized.append(cm_normalized)

        # Average the accuracy, and the confusion matrix element-wise, over the n runs
        avg_test_acc = mean(temp_test_acc)
        avg_cm_normalized = mean(temp_cm_normalized, axis=0)

        results[(err, img_size)] = {'test_acc': avg_test_acc, 'cm_normalized': avg_cm_normalized}

        # Checkpoint after every combination. Dump to a temp file and swap it in
        # atomically: an interrupt mid-dump then cannot leave a truncated .pkl that
        # would make the load above fail and lose every saved result.
        tmp_path = ckpt_path + '.tmp'
        with open(tmp_path, 'wb') as f:
            pickle.dump(results, f)
        os.replace(tmp_path, ckpt_path)

## Plotting 
- Load the results of the three models
- Compare the three models across the chosen `err_percentages` and image sizes, using the mean of the diagonal of the normalized confusion matrix as the metric

In [None]:
import pickle

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np


def _load_results(path):
    """Return the results dict pickled at *path*, or an empty dict if the file is missing.

    A missing file means that model's experiment cell has not been run yet; its
    combinations are filled with NaN placeholders below instead of aborting.
    """
    try:
        with open(path, 'rb') as f:
            return pickle.load(f)
    except FileNotFoundError:
        print('missing results file:', path)
        return {}


res_noartifact = _load_results('res_noartifact.pkl')
results = _load_results('results.pkl')
results_2hop_v2 = _load_results('results_hop.pkl')

imgs = [10, 20]
# err_percentages is reused from the parameter cell at the top of the notebook.

# Template for (err, img) combinations that were never computed. NaN values make
# matplotlib leave a gap instead of drawing a fake (negative) point.
default_value = {'test_acc': np.nan, 'cm_normalized': np.full((3, 3), np.nan), 'cm_m': np.nan}

datasets = [results, results_2hop_v2, res_noartifact]

# Populate missing (err, img) combinations and store, for every entry, the mean of
# the confusion-matrix diagonal under 'cm_m' (the value plotted below).
for dataset in datasets:
    for img in imgs:
        for err in err_percentages:
            if (err, img) not in dataset:
                # Fresh copy per key: inserting default_value itself would make every
                # missing entry, across all datasets, alias (and mutate) one dict.
                dataset[(err, img)] = {**default_value, 'cm_normalized': default_value['cm_normalized'].copy()}
            cm_normalized = dataset[(err, img)]['cm_normalized']
            cm_m = np.mean(np.diag(cm_normalized))
            dataset[(err, img)]['cm_m'] = cm_m  # Storing the mean of the diagonal

# Setting plot aesthetics
font = {'family': 'Arial', 'weight': 'normal', 'size': 18}
plt.rc('font', **font)

datasets = [results, results_2hop_v2, res_noartifact]
titles = ['Baseline', 'Hopfield-rec', 'No-artifacts']


def _style_cm_axis(ax, xticks, xlabel, title):
    """Apply the y-axis/grid/legend formatting shared by every per-model subplot."""
    ax.set_xticks(xticks)
    ax.set_yticks([0.25, 0.50, 0.75, 1])
    ax.set_ylim([0, 1])
    # 'cm_m' is a fraction in [0, 1]; render the ticks as percentages so they
    # match the '(%)' in the axis label.
    ax.yaxis.set_major_formatter(ticker.PercentFormatter(xmax=1.0))
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel('Test Mean of diagonal CM  (%)')
    ax.grid(True, linestyle='--')
    ax.legend()


# Figure 1: one subplot per model, mean CM diagonal vs. error, one line per image size.
fig, axes = plt.subplots(1, 3, figsize=(24, 8))
for ax, data, title in zip(axes, datasets, titles):
    for img_size in imgs:
        acc_values = [data[(err, img_size)]['cm_m'] for err in err_percentages]
        ax.plot(err_percentages, acc_values, marker='o', label=f"Image Size {img_size}")
    _style_cm_axis(ax, err_percentages, 'Error Percentage', f"{title} - Diagonal CM vs. Error")

# Figure 2: one subplot per model, mean CM diagonal vs. image size, one line per error.
fig, axes = plt.subplots(1, 3, figsize=(24, 8))
for ax, data, title in zip(axes, datasets, titles):
    for err in err_percentages:
        acc_values = [data[(err, img_size)]['cm_m'] for img_size in imgs]
        ax.plot(imgs, acc_values, marker='o', label=f"Error {err}")
    _style_cm_axis(ax, imgs, 'Image Size', f"{title} - Diagonal CM vs. Image Size")


# Comparison figure: all three models on the same axes, one subplot per image size.
font = {'family': 'Times New Roman', 'weight': 'normal', 'size': 18}
plt.rc('font', **font)

datasets = [results, results_2hop_v2, res_noartifact]
colors = ['g', 'b', 'r']
titles = ['Baseline', 'Hopfield-rec', 'No-artifacts']

fig, ax = plt.subplots(ncols=2, figsize=(16, 4))

for axis, img_size in zip(ax, [10, 20]):
    for data, color, title in zip(datasets, colors, titles):
        acc_values = [data[(err, img_size)]['cm_m'] for err in err_percentages if (err, img_size) in data]
        axis.plot(err_percentages, acc_values, marker='o', label=title, color=color)

    axis.set_xticks(err_percentages)
    axis.set_yticks([0.4, 0.75, 1])
    axis.set_ylim([0, 1])
    # 'cm_m' is a fraction in [0, 1]; show percentage ticks to match the '(%)' label.
    axis.yaxis.set_major_formatter(ticker.PercentFormatter(xmax=1.0))
    axis.set_title(f'Mean of diagonal CM  vs. Error img_size={img_size}')
    axis.set_xlabel('Error Percentage')
    axis.set_ylabel('Test Mean of diagonal CM  (%)')
    axis.grid(True, linestyle='--')
    axis.legend()
