<a href="https://colab.research.google.com/github/ReveRoyl/MT_ML_Decoding/blob/main/CNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# CNN

## Package installation and imports

In [None]:
# Mount Google Drive; the project loader imported in a later cell is read from MyDrive.
import google.colab.drive as drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!pip install mne
!pip install sklearn
!pip install tensorflow
!pip install -U dm-haiku
!pip install optax

In [None]:
import sys
sys.path.append('drive/MyDrive/load/code')
from load_data import load_MEG_dataset
import haiku as hk
import jax
import optax
from jax import numpy as jnp
import numpy as np
from sklearn.metrics import classification_report

## Load data

In [None]:
from sklearn.model_selection import train_test_split

# Load every subject once and hold out a stratified, seeded 20% of the trials as
# the test set. Previously X_test/y_test came from a second load of the very same
# subjects, so the "test" set was identical to the training set and the reported
# test accuracy said nothing about generalisation.
# NOTE(review): stratify assumes y is a 1-D vector of class labels. This is a
# within-subject (trial-level) split; hold out whole subjects instead for a
# cross-subject estimate.
subjects = [str(i).zfill(3) for i in range(1, 5)]
X_all, y_all = load_MEG_dataset(subjects, mode='concatenate', output_format='numpy')
X_train, X_test, y_train, y_test = train_test_split(
    X_all, y_all, test_size=0.2, stratify=y_all, random_state=42)
del X_all, y_all  # release the unsplit copy
X_train, X_test, y_train, y_test = (jnp.array(split, dtype=jnp.float32)
                                    for split in (X_train, X_test, y_train, y_test))
print('loading done')

In [None]:
# Sanity check: NaNs in either split would turn the training loss and gradients
# into NaN, so check both splits (previously only X_test was checked).
{'X_train': bool(np.isnan(X_train).any()), 'X_test': bool(np.isnan(X_test).any())}

## CNN model

In [None]:
# Distinct training label values; len(classes) sizes the CNN output layer and the one-hot targets.
classes = jnp.unique(y_train.ravel())
class CNN(hk.Module):
    """Two-block convolutional classifier that returns class probabilities.

    Each block is a 3x3 SAME convolution, a ReLU and a 2x2 / stride-2 max-pool.
    The pooled features are flattened, mapped to one output per class and
    normalised with softmax.

    Args:
        num_classes: width of the output layer. Defaults to ``len(classes)``
            (the distinct training labels computed in the notebook), so a
            plain ``CNN()`` behaves exactly as before.
    """

    def __init__(self, num_classes=None):
        super().__init__(name="CNN")
        if num_classes is None:
            num_classes = len(classes)
        # Submodules are created in the same order as before, so parameter
        # names (and therefore saved params) are unchanged.
        self.conv1 = hk.Conv2D(output_channels=32, kernel_shape=(3,3), padding="SAME")
        self.conv2 = hk.Conv2D(output_channels=16, kernel_shape=(3,3), padding="SAME")
        self.flatten = hk.Flatten()
        self.linear = hk.Linear(num_classes)

    def __call__(self, x_batch):
        """Forward pass returning softmax probabilities, one row per sample.

        NOTE(review): hk.Conv2D defaults to NHWC input; confirm the loader's
        array layout matches.
        """
        x = x_batch
        for conv in (self.conv1, self.conv2):
            # ReLU and max-pooling commute, so applying them in the same order
            # for both blocks gives the same result as the original layout.
            x = jax.nn.relu(conv(x))
            x = hk.MaxPool(window_shape=(2, 2), strides=(2, 2), padding='SAME')(x)
        x = self.flatten(x)
        x = self.linear(x)
        return jax.nn.softmax(x)

def ConvNet(x):
    """Haiku forward function: build a fresh CNN module and apply it to `x`.

    Only valid inside `hk.transform`, which turns it into init/apply functions.
    """
    return CNN()(x)

# Turn the forward function into pure init/apply functions.
conv_net = hk.transform(ConvNet)

rng = jax.random.PRNGKey(42)
# Shapes are inferred from a 5-sample slice of the training data.
params = conv_net.init(rng, X_train[:5])

print("Weights Type : {}\n".format(type(params)))

# Every parameterised submodule here (convs, linear) has a weight and a bias.
for module_name, module_params in params.items():
    print(module_name)
    w_shape, b_shape = module_params["w"].shape, module_params["b"].shape
    print("Weights : {}, Biases : {}\n".format(w_shape, b_shape))

# Smoke test: forward pass with the freshly initialised weights.
preds = conv_net.apply(params, rng, X_train[:5])
preds[:5]

## Loss function and weight-update function

In [None]:
def CrossEntropyLoss(weights, input_data, actual):
    """Categorical cross-entropy of the CNN's predictions, summed over a batch.

    Args:
        weights: Haiku parameter tree for `conv_net`.
        input_data: batch of model inputs.
        actual: class labels for the batch (may be float-typed). Assumed to be
            integer values in [0, len(classes)); TODO confirm the loader
            encodes labels that way.

    Returns:
        Scalar loss, summed (not averaged) over the batch.
    """
    preds = conv_net.apply(weights, rng, input_data)
    # one_hot matches labels against an integer range; cast explicitly rather
    # than relying on float equality with the float32 labels.
    labels = jnp.asarray(actual, dtype=jnp.int32)
    one_hot_actual = jax.nn.one_hot(labels, num_classes=len(classes))
    # The model already applies softmax, so a probability can underflow to
    # exactly 0 and log(0) = -inf would make the loss and gradients NaN. Clip
    # like Keras' probability-based cross-entropy does. (Cleaner long-term fix:
    # have the model return logits and use a log-softmax based loss.)
    log_preds = jnp.log(jnp.clip(preds, 1e-7, 1.0))
    return - jnp.sum(one_hot_actual * log_preds)
def UpdateWeights(weights, gradients, step_size=None):
    """One plain gradient-descent step for a single parameter leaf.

    Intended to be mapped over matching parameter/gradient trees, e.g.
    ``jax.tree_util.tree_map(UpdateWeights, params, grads)``.

    Args:
        weights: current parameter value.
        gradients: gradient of the loss w.r.t. `weights`.
        step_size: learning rate. When omitted, falls back to the notebook-level
            `learning_rate` (set in the training cell), so existing calls keep
            working while new callers no longer depend on cell order.

    Returns:
        ``weights - step_size * gradients``.
    """
    lr = learning_rate if step_size is None else step_size
    return weights - lr * gradients

## Train

In [None]:
from jax import value_and_grad

rng = jax.random.PRNGKey(42) ## Reproducibility: initialises the model with the same weights every run.

conv_net = hk.transform(ConvNet)
params = conv_net.init(rng, X_train[:5])
epochs = 25
batch_size = 256
learning_rate = jnp.array(1/1e4)

optimizer = optax.adam(learning_rate=learning_rate) ## Adam optimiser (not plain SGD)
optimizer_state = optimizer.init(params)

# Build (and compile) the loss/gradient function once instead of re-wrapping it
# for every batch. It retraces at most once more, for a smaller final batch.
loss_and_grads = jax.jit(value_and_grad(CrossEntropyLoss))

# Separate, seeded key for shuffling so the initialisation key stays untouched.
shuffle_key = jax.random.PRNGKey(0)
n_train = X_train.shape[0]

for i in range(1, epochs+1):
    # Reshuffle every epoch: the concatenated dataset may be ordered (e.g. by
    # subject or class), and ordered mini-batches bias every update.
    shuffle_key, perm_key = jax.random.split(shuffle_key)
    order = jax.random.permutation(perm_key, n_train)

    losses = [] ## Record loss of each batch
    # range(0, n, step) yields only non-empty batches. The old
    # `n // batch_size + 1` count produced an empty final batch (which still
    # triggered an Adam step) whenever n was a multiple of batch_size.
    for start in range(0, n_train, batch_size):
        batch_idx = order[start:start + batch_size]
        X_batch, Y_batch = X_train[batch_idx], y_train[batch_idx] ## Single batch of data

        loss, param_grads = loss_and_grads(params, X_batch, Y_batch) ## Forward pass, loss and grads
        updates, optimizer_state = optimizer.update(param_grads, optimizer_state) ## Parameter updates
        params = optax.apply_updates(params, updates) ## Update model weights
        losses.append(loss) ## Record Loss

    print("CrossEntropy Loss : {:.2f}".format(jnp.array(losses).mean()))

## Make predictions

In [None]:
def MakePredictions(weights, input_data, batch_size=32):
    """Run `conv_net` over `input_data` in mini-batches.

    Args:
        weights: Haiku parameter tree for `conv_net`.
        input_data: array whose first axis indexes samples.
        batch_size: samples per forward pass; must be positive.

    Returns:
        List of per-batch model outputs in input order. Concatenate along axis 0
        to get one row per sample.

    Raises:
        ValueError: if `batch_size` is not positive.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive, got {}".format(batch_size))
    n_samples = input_data.shape[0]
    # Only non-empty batches: the old `n // batch_size + 1` count added an empty
    # trailing batch whenever n was a multiple of batch_size. An empty input
    # still yields one (empty) batch so callers can always concatenate.
    starts = list(range(0, n_samples, batch_size)) or [0]
    return [conv_net.apply(weights, rng, input_data[start:start + batch_size])
            for start in starts]

In [None]:
# Predicted class index (arg-max of the softmax output) for every sample of
# each split, evaluated in batches of 256: train first, then test.
train_preds, test_preds = (
    jnp.concatenate(MakePredictions(params, split_X, 256)).squeeze().argmax(axis=1)
    for split_X in (X_train, X_test)
)

## Evaluation

In [None]:
from sklearn.metrics import accuracy_score

# Fraction of samples whose arg-max prediction equals the label, per split.
train_accuracy = accuracy_score(y_train, train_preds)
print("Train Accuracy : {:.3f}".format(train_accuracy))
test_accuracy = accuracy_score(y_test, test_preds)
print("Test  Accuracy : {:.3f}".format(test_accuracy))

Train Accuracy : 0.125
Test  Accuracy : 0.125


In [None]:
# Per-class precision / recall / F1 on the test split, under a heading line.
print("Test Classification Report ", classification_report(y_test, test_preds), sep="\n")