In [1]:
import tensorflow as tf
import numpy as np
from EEGNet_def import EEGNet, fastWeights_EEGNet
from utils import generate_user_data
import pickle
import tensorflow.keras.backend as K
from tensorflow.keras.models import clone_model
import tensorflow as tf
import gc
import tracemalloc
import random
from random import shuffle
from copy import deepcopy, copy
from tensorflow.keras import utils as np_utils
from tensorflow.keras import metrics
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.callbacks import ModelCheckpoint
import matplotlib.pyplot as plt
tf.compat.v1.enable_eager_execution()

class MAML:
    def __init__(self, model_params, input_shape=(64, 128, 1), num_classes=2, inner_lr=0.001, outer_lr=0.001, 
                 inner_steps=5, outer_steps=100, fine_tune_steps = 10, num_users=5, support_size=20, query_size=20):
        self.inner_steps = inner_steps
        self.outer_steps = outer_steps
        self.fine_tune_steps = fine_tune_steps
        self.num_classes = num_classes
        self.inner_lr = inner_lr
        self.outer_lr = outer_lr
        self.support_size = support_size
        self.query_size = query_size
        self.num_users = num_users
        self.model = self.create_eegnet(num_classes, model_params)
        self.meta_optimizer = tf.keras.optimizers.Adam(outer_lr)
        self.loss_fn = tf.keras.losses.BinaryCrossentropy()

    def create_eegnet(self,num_classes,model_params):
        num_channels, num_time_samples, model_dropout, model_kern, model_f1, model_d, model_f2 = model_params
        eegnet = EEGNet(nb_classes=num_classes,Chans=num_channels, Samples=num_time_samples, dropoutRate = model_dropout, 
                kernLength = model_kern, F1 = model_f1, D = model_d, F2 = model_f2)  
        eegnet.build((None, num_channels, num_time_samples,1))
        return eegnet
                
    def TrainMAML(self, user_data, test_users):
        def InnerUpdate(batch):
            support_x,support_y,query_x,query_y = batch
            with tf.GradientTape() as taskTape:
                loss = self.loss_fn(support_y, self.model(support_x))
            
            grads = taskTape.gradient(loss, self.model.trainable_weights)
            weights = [w - self.inner_lr * g for g, w in zip(grads, self.model.trainable_weights)]
            return self.loss_fn(query_y, fastWeights_EEGNet(self.model, weights, query_x))
    
        total_train_loss=0
        test_acc_list = [[] for i in range(len(test_users))]
        for idx,user in enumerate(test_users):
            support_x, support_y = user_data[user]["support"]
            query_x, query_y = user_data[user]["query"]
            test_x, test_y = user_data[user]["test"]
            with tf.GradientTape() as tape:	
                # loss = tf.map_fn(taskLoss, elems=(support_x,support_y,query_x,query_y),fn_output_signature=tf.float32)
                # loss = tf.reduce_sum(batchLoss)
                loss = InnerUpdate([support_x,support_y,query_x,query_y])
                total_train_loss+=loss
            meta_gradients = tape.gradient(loss, self.model.trainable_variables)
            self.meta_optimizer.apply_gradients(zip(meta_gradients, self.model.trainable_variables))
        avg_train_loss = total_train_loss / (len(test_users))
        print(f"Train Loss: {avg_train_loss:.4f}")
        return avg_train_loss

    def train(self, user_data, train_user, test_users, pretraining_epochs=100, pretrain_lr=0.001, pretrain=False):
        train_loss_list=[]
        self.model.compile(loss='binary_crossentropy',optimizer=tf.keras.optimizers.Adam(learning_rate=pretrain_lr),
          metrics=[metrics.SpecificityAtSensitivity(0.5, num_thresholds=50)])
        self.model.summary()
        for step in range(self.outer_steps):
            train_loss_list.append(self.TrainMAML(user_data, test_users))
            if step % 1 == 0:
                print(f"Step {step}/{self.outer_steps} completed")
        return self.model

    def test_model(self, model_params, user_data, test_users):
        acc_fine_tune=[]
        for user in test_users:
            user_model = self.create_eegnet(self.num_classes, model_params)
            user_model.compile(optimizer='adam', loss='binary_crossentropy')
            user_model.load_weights('model_weights/maml/model')
            support_x, support_y = user_data[user]["support"]
            test_data, test_labels = user_data[user]["test"]
            test_labels = tf.cast(test_labels, tf.int64)
            test_labels = tf.argmax(test_labels, axis=1)  # Convert categorical to class index

            # Fine-tune on the support set
            user_model.fit(support_x, support_y, epochs=self.fine_tune_steps, verbose=0)
            outputs = user_model(test_data)
            predicted = tf.argmax(outputs, axis=1)
            accuracy = tf.reduce_mean(tf.cast(tf.equal(predicted, test_labels), tf.float32))
            acc_fine_tune.append(accuracy.numpy())
        print(f"Test accuracy: {np.mean(acc_fine_tune):.4f}")

num_classes = 2  # Example binary classification
num_users = 10  
num_samples = 400  
num_channels = 12  
num_time_samples = 100  
support_size = 40  # Number of support samples per user
query_size = 40    # Number of query samples per user
model_epochs=100
model_dropout=0.4
model_kern=16
model_f1=16
model_f2=16
model_d=4
metatrain_iter = 50  # Number of MAML updates
innertrain_iter = 5  # Number of inner gradient updates
fine_tune_steps = 5  # Number of iterations for fine-tuning
pretrain_lr = 0.001    # Inner loop learning rate
inner_lr = 0.001    # Inner loop learning rate
meta_lr = 0.001    # Meta-learning rate
num_tasks = 5

user_data = generate_user_data(num_users, num_samples, num_channels, num_time_samples, support_size, query_size)
model_params = [num_channels, num_time_samples, model_dropout, model_kern, model_f1, model_d, model_f2]

keys = list(user_data.keys())
for i in range(num_users):
    test_users = [x for x in keys if x != i]
    shuffled_users = random.sample(test_users, len(test_users))
    meta_users = shuffled_users[:num_tasks]
    unseen_users = shuffled_users[num_tasks:]
    maml = MAML(model_params, input_shape=(num_channels, num_time_samples, 1), num_classes=num_classes, inner_lr=inner_lr,
                outer_lr=meta_lr, inner_steps=innertrain_iter, outer_steps=metatrain_iter, fine_tune_steps = fine_tune_steps,
                num_users=num_tasks, support_size=support_size, query_size=query_size)
    meta_model = maml.train(user_data,i,meta_users,pretrain_lr=pretrain_lr, pretrain=False)
    meta_model.save_weights('model_weights/maml/model')
    maml.test_model(model_params, user_data, meta_users)
    maml.test_model(model_params, user_data, unseen_users)

2025-03-16 03:31:51.458920: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.10.1
2025-03-16 03:31:57.229877: I tensorflow/compiler/jit/xla_cpu_device.cc:41] Not creating XLA devices, tf_xla_enable_xla_devices not set
2025-03-16 03:31:57.231730: W tensorflow/stream_executor/platform/default/dso_loader.cc:60] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/pace-apps/spack/packages/linux-rhel9-x86_64_v3/gcc-12.3.0/mvapich2-2.3.7-1-qv3gjagtbx5e3rlbdy6iy2sfczryftyt/lib:/opt/slurm/current/lib:/opt/pmix/4.2.6/lib:/usr/local/pace-apps/spack/packages/linux-rhel9-x86_64_v3/gcc-12.3.0/libpciaccess-0.17-pjfe4ct4gfm5k26s36hmewhbz4k232dl/lib:/usr/local/pace-apps/spack/packages/linux-rhel9-x86_64_v3/gcc-11.3.1/gcc-12.3.0-ukkkutsxfl5kpnnaxflpkq2jtliwthfz/lib64:/usr/local/pace-apps/spack/packages/linux-rhel9-x86_64_v3/gcc-11.3.1/gc

Model: "eeg_net"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d (Conv2D)              multiple                  256       
_________________________________________________________________
batch_normalization (BatchNo multiple                  64        
_________________________________________________________________
depthwise_conv2d (DepthwiseC multiple                  768       
_________________________________________________________________
batch_normalization_1 (Batch multiple                  256       
_________________________________________________________________
activation (Activation)      multiple                  0         
_________________________________________________________________
average_pooling2d (AveragePo multiple                  0         
_________________________________________________________________
dropout (Dropout)            multiple                  0   

2025-03-16 03:32:06.616934: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:116] None of the MLIR optimization passes are enabled (registered 2)
2025-03-16 03:32:06.840118: I tensorflow/core/platform/profile_utils/cpu_utils.cc:112] CPU Frequency: 2700000000 Hz


Test accuracy: 0.5056
Test accuracy: 0.4797
Model: "eeg_net_10"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_10 (Conv2D)           multiple                  256       
_________________________________________________________________
batch_normalization_30 (Batc multiple                  64        
_________________________________________________________________
depthwise_conv2d_10 (Depthwi multiple                  768       
_________________________________________________________________
batch_normalization_31 (Batc multiple                  256       
_________________________________________________________________
activation_20 (Activation)   multiple                  0         
_________________________________________________________________
average_pooling2d_20 (Averag multiple                  0         
_________________________________________________________________
dropout_20 (

KeyboardInterrupt: 