In [6]:
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Activation, Permute, Dropout
from tensorflow.keras.layers import Conv2D, MaxPooling2D, AveragePooling2D
from tensorflow.keras.layers import SeparableConv2D, DepthwiseConv2D
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import SpatialDropout2D
from tensorflow.keras.regularizers import l1_l2
from tensorflow.keras.layers import Input, Flatten
from tensorflow.keras.constraints import max_norm
from tensorflow.keras import backend as K

def EEGNet(nb_classes, Chans = 64, Samples = 128, 
             dropoutRate = 0.5, kernLength = 64, F1 = 8, 
             D = 2, F2 = 16, norm_rate = 0.25, dropoutType = 'Dropout'):
    """ Keras Implementation of EEGNet
    http://iopscience.iop.org/article/10.1088/1741-2552/aace8c/meta

    Inputs:

      nb_classes      : int, number of classes to classify
      Chans, Samples  : number of channels and time points in the EEG data
      dropoutRate     : dropout fraction
      kernLength      : length of temporal convolution in first layer. We found
                        that setting this to be half the sampling rate worked
                        well in practice. For the SMR dataset in particular
                        since the data was high-passed at 4Hz we used a kernel
                        length of 32.
      F1, F2          : number of temporal filters (F1) and number of pointwise
                        filters (F2) to learn. Default: F1 = 8, F2 = F1 * D.
      D               : number of spatial filters to learn within each temporal
                        convolution. Default: D = 2
      norm_rate       : max-norm constraint applied to the kernel of the final
                        Dense classification layer. Default: 0.25
      dropoutType     : Either SpatialDropout2D or Dropout, passed as a string.

    Input layout (follows the global Keras setting K.image_data_format()):
      'channels_first' -> model input shape (1, Chans, Samples)
      'channels_last'  -> model input shape (Chans, Samples, 1)

    Returns:
      An uncompiled keras Model ending in a softmax over nb_classes. The layers
      'elu_1', 'elu_2', 'flatten', 'dense' and 'softmax' are named so they can
      be fetched with get_layer().

    Raises:
      ValueError: if dropoutType is not 'SpatialDropout2D' or 'Dropout'.
    """

    if dropoutType == 'SpatialDropout2D':
        dropoutType = SpatialDropout2D
    elif dropoutType == 'Dropout':
        dropoutType = Dropout
    else:
        raise ValueError('dropoutType must be one of SpatialDropout2D '
                         'or Dropout, passed as a string.')

    # Conv/pool layers below use Keras' global data format, so the input
    # layout and the batch-norm feature axis must follow it. A fixed
    # (1, Chans, Samples) input only builds under 'channels_first': under
    # 'channels_last' the (Chans, 1) depthwise kernel would be applied to a
    # height-1 feature map and model construction fails.
    if K.image_data_format() == 'channels_first':
        input_shape = (1, Chans, Samples)
        bn_axis = 1
    else:
        input_shape = (Chans, Samples, 1)
        bn_axis = -1

    input1 = Input(shape = input_shape)

    ##################################################################
    # Block 1: temporal convolution over time points, then a depthwise
    # convolution spanning all Chans electrodes (D filters per temporal map).
    # (No `input_shape=` here: it is ignored for layers called on a tensor.)
    block1 = Conv2D(F1, (1, kernLength), padding = 'same',
                    use_bias = False)(input1)
    block1 = BatchNormalization(axis = bn_axis)(block1)
    block1 = DepthwiseConv2D((Chans, 1), use_bias = False,
                             depth_multiplier = D,
                             depthwise_constraint = max_norm(1.))(block1)
    block1 = BatchNormalization(axis = bn_axis)(block1)
    block1 = Activation('elu', name = 'elu_1')(block1)
    block1 = AveragePooling2D((1, 4))(block1)
    block1 = dropoutType(dropoutRate)(block1)

    # Block 2: separable convolution mapping the F1 * D maps to F2 maps.
    block2 = SeparableConv2D(F2, (1, 16),
                             use_bias = False, padding = 'same')(block1)
    block2 = BatchNormalization(axis = bn_axis)(block2)
    block2 = Activation('elu', name = 'elu_2')(block2)
    block2 = AveragePooling2D((1, 8))(block2)
    block2 = dropoutType(dropoutRate)(block2)

    # Classifier head: max-norm constrained Dense followed by softmax.
    flatten = Flatten(name = 'flatten')(block2)
    dense = Dense(nb_classes, name = 'dense',
                  kernel_constraint = max_norm(norm_rate))(flatten)
    softmax = Activation('softmax', name = 'softmax')(dense)

    return Model(inputs = input1, outputs = softmax)


In [None]:
def build_discriminator(input_tensor):
    """Build the domain discriminator used for adversarial adaptation.

    A small MLP (400 -> 100 ReLU units) ending in a single sigmoid unit that
    scores each sample's features.

    Args:
      input_tensor: a Keras ``Input`` tensor holding the features to score.
        It may be multi-dimensional (e.g. a convolutional feature map); it is
        flattened first so the model emits one probability per sample.

    Returns:
      An uncompiled keras Model mapping ``input_tensor`` to shape (batch, 1).
    """
    # Dense on an N-D tensor acts only on the last axis, which would give one
    # score per spatial position rather than one per sample. Flatten first;
    # for inputs that are already (batch, features) this is a no-op.
    hidden = Flatten()(input_tensor)
    hidden = Dense(400, activation = 'relu')(hidden)
    hidden = Dense(100, activation = 'relu')(hidden)
    score = Dense(1, activation = 'sigmoid')(hidden)
    # tf.keras' Model expects the keyword `outputs`; `output` is rejected.
    return Model(inputs = input_tensor, outputs = score)


In [None]:
def data_generator(X_train, y_train, batch_size, rng = None):
    """Yield shuffled (x, y) mini-batches forever.

    Each pass over the data draws one fresh permutation and then yields
    ``len(X_train) // batch_size`` disjoint batches, so every sample appears at
    most once per pass (the remainder is dropped). If the data set has fewer
    than ``batch_size`` samples, each pass yields the whole (shuffled) data set
    as a single smaller batch instead of looping without ever yielding.

    Args:
      X_train, y_train: numpy arrays with the same first dimension (rows are
        selected with integer-array indexing).
      batch_size: positive int, number of samples per batch.
      rng: optional ``numpy.random.Generator`` for reproducible shuffling;
        when None the global ``np.random`` state is used.

    Yields:
      (X_batch, y_batch) tuples whose rows correspond to each other.

    Raises:
      ValueError: on the first ``next()`` call if ``batch_size`` is not
        positive, the data set is empty, or X_train / y_train lengths differ.
    """
    if batch_size <= 0:
        raise ValueError('batch_size must be a positive integer')
    total = len(X_train)
    if total == 0:
        raise ValueError('cannot draw batches from an empty data set')
    if len(y_train) != total:
        raise ValueError('X_train and y_train must have the same length')

    # Integer division: `total / batch_size` is a float in Python 3 and
    # range() rejects it. max(1, ...) avoids an endless non-yielding loop
    # when the data set is smaller than one batch.
    n_batches = max(1, total // batch_size)
    permutation = np.random.permutation if rng is None else rng.permutation

    while True:
        # Shuffle once per pass (not once per batch) so batches within a pass
        # are disjoint; index only the rows of each batch instead of copying
        # the full arrays on every step.
        order = permutation(total)
        for i in range(n_batches):
            idx = order[i * batch_size:(i + 1) * batch_size]
            yield X_train[idx], y_train[idx]


In [None]:
def train(nb_classes, chans, samples, epochs = 1000, log_every = 10):
    """Adversarial domain adaptation (ADDA-style) of EEGNet from source to target.

    Steps:
      1. Pre-train a source EEGNet on the labelled source data.
      2. Initialise a target EEGNet with the pre-trained source weights.
      3. Alternate between training a domain discriminator on the 'elu_2'
         features (source -> 1, target -> 0) and training the target encoder,
         through the frozen discriminator, so that target features are
         classified as source (target labelled 1).

    Args:
      nb_classes: number of output classes of both EEGNets.
      chans, samples: EEG channels and time points per trial.
      epochs: number of adversarial rounds (default 1000).
      log_every: positive int; averaged losses are printed every this many
        rounds.

    Module-level names read by this function (defined in other notebook cells):
      X_train, Y_train            : labelled source-domain data
      X_validate, Y_validate      : source validation data
      X_test, Y_test              : target-domain data (labels are not used)
      checkpointer, class_weights : passed to source_model.fit
      disc, clf                   : discriminator / target-encoder updates per
                                    adversarial round
      np, data_generator          : numpy and the batch generator
    NOTE(review): Y_train / Y_validate are assumed to be one-hot with
    nb_classes columns; confirm this against the data-loading cells.

    Returns:
      (target_model, target_feature_model, discriminator_model).
      target_feature_model shares its layers with target_model. The final
      Dense layer of target_model is not on the adversarial path, so it keeps
      the source-trained weights copied in step 2.
    """
    # --- Build the models ------------------------------------------------
    source_model = EEGNet(nb_classes=nb_classes, Chans=chans, Samples=samples,
                          dropoutRate=0.65, kernLength=32, F1=8, D=2, F2=16,
                          dropoutType='Dropout')
    # Softmax over one-hot classes -> categorical cross-entropy.
    source_model.compile(loss = 'categorical_crossentropy', optimizer = 'adam',
                         metrics = ['accuracy'])

    # The source encoder features come from source_model, which exists at
    # this point. target_model is only created further down.
    source_feature_model = Model(inputs = source_model.input,
                                 outputs = source_model.get_layer('elu_2').output)

    # Layers have no `.shape`; use the per-sample shape of the 'elu_2'
    # activation (the batch dimension is dropped).
    feature_shape = tuple(source_model.get_layer('elu_2').output.shape[1:])
    feature_tensor = Input(shape = feature_shape)

    discriminator_model = build_discriminator(feature_tensor)
    discriminator_model.compile(loss = 'binary_crossentropy', optimizer = 'adam',
                                metrics = ['accuracy'])
    # Freeze the discriminator for the combined model compiled below. Per the
    # Keras docs, the trainable state at compile time is what counts, so
    # discriminator_model.train_on_batch still updates the discriminator.
    discriminator_model.trainable = False

    target_model = EEGNet(nb_classes=nb_classes, Chans=chans, Samples=samples,
                          dropoutRate=0.65, kernLength=32, F1=8, D=2, F2=16,
                          dropoutType='Dropout')
    target_feature_model = Model(inputs = target_model.input,
                                 outputs = target_model.get_layer('elu_2').output)

    target_features = target_feature_model(target_model.input)
    combined_model = Model(inputs = target_model.input,
                           outputs = discriminator_model(target_features))
    combined_model.compile(loss = 'binary_crossentropy', optimizer = 'adam',
                           metrics = ['accuracy'])

    # --- Pre-train the source model --------------------------------------
    source_model.fit(X_train, Y_train, batch_size = 25, epochs = 30,
                     verbose = 1, validation_data = (X_validate, Y_validate),
                     callbacks = [checkpointer], class_weight = class_weights,
                     shuffle = True)

    # Start the target encoder from the pre-trained source encoder. Both
    # models are built with identical EEGNet arguments, so the weights match.
    target_model.set_weights(source_model.get_weights())

    source_batches = data_generator(X_train, Y_train, 64)
    target_batches = data_generator(X_test, Y_test, 64)

    # --- Adversarial training --------------------------------------------
    # Start at scalar 0.0: metrics_names can be empty before the first
    # train_on_batch call, so pre-sized arrays would not broadcast.
    disc_loss_sum, target_loss_sum = 0.0, 0.0
    disc_updates, target_updates = 0, 0

    for epoch in range(epochs):

        # Discriminator step: source features -> 1, target features -> 0.
        for _ in range(disc):
            xs, _ys = next(source_batches)
            xt, _yt = next(target_batches)

            source_feature = source_feature_model.predict(xs, verbose = 0)
            target_feature = target_feature_model.predict(xt, verbose = 0)

            disc_x = np.concatenate((source_feature, target_feature))
            # A single sigmoid unit needs one 0/1 label per sample, with
            # shape (n, 1), not one-hot (n, 2).
            disc_y = np.concatenate((np.ones((len(source_feature), 1)),
                                     np.zeros((len(target_feature), 1))))

            disc_loss_sum = disc_loss_sum + np.asarray(
                discriminator_model.train_on_batch(disc_x, disc_y), dtype = float)
            disc_updates += 1

        # Target-encoder step: label target samples as "source" (1) so the
        # encoder learns to fool the frozen discriminator.
        for _ in range(clf):
            xt, _yt = next(target_batches)
            xt_2, _yt_2 = next(target_batches)

            combine_x = np.concatenate((xt, xt_2), axis = 0)
            combine_y = np.ones((len(combine_x), 1))

            target_loss_sum = target_loss_sum + np.asarray(
                combined_model.train_on_batch(combine_x, combine_y), dtype = float)
            target_updates += 1

        if (epoch % log_every) == 0:
            # Average over the updates actually made since the last report.
            print("loss target", target_loss_sum / max(target_updates, 1))
            print("loss discrimnator", disc_loss_sum / max(disc_updates, 1))

            disc_loss_sum, target_loss_sum = 0.0, 0.0
            disc_updates, target_updates = 0, 0

    return target_model, target_feature_model, discriminator_model
                    
# Persist the target feature encoder and the domain discriminator as HDF5 files.
# NOTE(review): `target_feature_model` and `discriminator_model` are created as
# local variables inside train(), and train() is not called anywhere in the visible
# cells. On a fresh kernel these two lines raise NameError unless an earlier cell
# binds both names at module level — confirm how this cell is meant to be reached.
target_feature_model.save("targetModel/target_model.hdf5")
discriminator_model.save("discriminatorModel/discriminator_model.hdf5")