This notebook implements a Semi-Supervised GAN (SGAN) that learns to classify MNIST digits from only 100 labeled images (10 per class), together with the full MNIST training set used as unlabeled data.

## Import Libraries

In [1]:
# example of semi-supervised gan for mnist
from numpy import expand_dims, zeros, ones, asarray
from numpy.random import randn, randint

from keras.datasets.mnist import load_data
from keras.optimizers import Adam
from keras.models import Model, Sequential

from keras.layers import Input, Dense, Reshape, Flatten, Conv2D, Conv2DTranspose
from keras.layers import LeakyReLU, Dropout, Lambda, Activation

#from keras.utils import to_categorical

from matplotlib import pyplot as plt
from keras import backend as K
import numpy as np

## Define the Generator Model

In [2]:
def define_generator(latent_dim):
    """Build the (uncompiled) generator mapping a latent vector to a 28x28x1 image.

    Shape path: Dense(7*7*256) -> reshape to (7, 7, 256) -> stride-2 transposed conv
    to (14, 14, 128) -> stride-1 transposed conv to (14, 14, 64) -> stride-2 transposed
    conv to (28, 28, 1). The final tanh keeps pixel values in [-1, 1].

    Args:
        latent_dim: length of the input noise vector.

    Returns:
        A functional ``Model`` from ``(latent_dim,)`` to ``(28, 28, 1)``.
    """
    noise = Input(shape=(latent_dim,))

    # Project the noise and reshape it into a small 7x7 feature map.
    hidden = Dense(7 * 7 * 256)(noise)
    hidden = LeakyReLU(alpha=0.2)(hidden)
    hidden = Reshape((7, 7, 256))(hidden)

    # Two transposed-conv stages: upsample 7->14, then refine at 14x14.
    for filters, stride in ((128, 2), (64, 1)):
        hidden = Conv2DTranspose(filters, (3, 3), strides=(stride, stride), padding='same')(hidden)
        hidden = LeakyReLU(alpha=0.2)(hidden)

    # Final upsample 14->28 down to a single tanh-activated channel.
    image = Conv2DTranspose(1, (3, 3), strides=(2, 2), activation='tanh', padding='same')(hidden)
    return Model(noise, image)

# The generator projects the latent vector with a dense layer, reshapes it to 7x7x256, and upsamples with Conv2DTranspose layers (7x7 -> 14x14 -> 28x28) to produce a 28x28 grayscale image with tanh outputs in [-1, 1].

## Define the Base Discriminator Model

In [3]:
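## The base discriminator outputs one unnormalized logit per class; its weights are shared by the supervised (softmax) and unsupervised (real/fake) discriminator models built on top of it.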

def define_discriminator(in_shape=(28, 28, 1), n_classes=10):
    """Build the shared (uncompiled) discriminator trunk.

    Three stride-2 3x3 convolutions (32, 64, 128 filters) with LeakyReLU, then
    flatten, 40% dropout, and a linear Dense layer. The output is one unnormalized
    logit per class. Activation is left to the heads that wrap this model.

    Args:
        in_shape: input image shape.
        n_classes: number of output logits.

    Returns:
        A functional ``Model`` from ``in_shape`` to ``(n_classes,)`` logits.
    """
    image = Input(shape=in_shape)

    features = image
    for n_filters in (32, 64, 128):
        features = Conv2D(n_filters, (3, 3), strides=(2, 2), padding='same')(features)
        features = LeakyReLU(alpha=0.2)(features)

    features = Flatten()(features)
    features = Dropout(0.4)(features)
    logits = Dense(n_classes)(features)
    return Model(inputs=image, outputs=logits)


## Define the Supervised Discriminator

In [4]:
## This model uses the base discriminator and adds a softmax layer for multiclass classification.

def define_sup_discriminator(disc):
    """Wrap the base discriminator with a softmax to form the supervised classifier.

    ``disc`` itself (not a copy) becomes the first layer. When the same ``disc`` is also
    passed to ``define_unsup_discriminator`` (as the training cell does), both heads
    share and update the same trunk weights.

    Args:
        disc: base model emitting one logit per class.

    Returns:
        A compiled ``Sequential`` model trained with sparse categorical cross-entropy
        on integer labels, reporting accuracy.
    """
    model = Sequential()
    model.add(disc)
    model.add(Activation('softmax'))
    # `learning_rate` is the supported argument name; the old `lr` alias is
    # deprecated in Keras 2 optimizers and rejected outright by Keras 3.
    model.compile(optimizer=Adam(learning_rate=0.0002, beta_1=0.5),
                  loss="sparse_categorical_crossentropy",
                  metrics=['accuracy'])
    return model


## Define the Unsupervised Discriminator

In [5]:
## This model uses a custom activation function to output a single probability indicating if an image is real or fake.

def custom_activation(x):
    """Collapse per-class logits into a single real-vs-fake probability.

    Computes D(x) = Z(x) / (Z(x) + 1), where Z(x) = sum_k exp(x_k) over the last axis.

    The direct form overflows: exp() of a large logit is inf, and inf / (inf + 1)
    is NaN. Since sigmoid(log Z) = 1 / (1 + 1/Z) = Z / (Z + 1), evaluating it as
    sigmoid(logsumexp(x)) gives the same value without the overflow.

    Args:
        x: logits tensor; the reduction is over the last axis.

    Returns:
        Tensor with the last axis reduced to size 1 (``keepdims=True``).
    """
    return K.sigmoid(K.logsumexp(x, axis=-1, keepdims=True))

def define_unsup_discriminator(disc):
    """Wrap the base discriminator with ``custom_activation`` to form the real/fake head.

    ``disc`` itself (not a copy) becomes the first layer, so its weights are shared with
    any other model built on the same ``disc`` (e.g. the supervised classifier).

    Args:
        disc: base model emitting one logit per class.

    Returns:
        A compiled ``Sequential`` model outputting one probability per image, trained
        with binary cross-entropy (1 = real, 0 = fake in the training loop).
    """
    model = Sequential()
    model.add(disc)
    model.add(Lambda(custom_activation))
    # `learning_rate` replaces the deprecated `lr` alias (removed in Keras 3).
    model.compile(loss='binary_crossentropy', optimizer=Adam(learning_rate=0.0002, beta_1=0.5))
    return model


## Define the GAN Model

In [6]:
## The GAN model combines the generator and the unsupervised discriminator.

def define_gan(gen_model, disc_unsup):
    """Stack generator -> unsupervised discriminator into the model used to train the generator.

    ``disc_unsup.trainable`` is set to False *before* compiling, so updates through the
    returned model only change the generator's weights. Models compiled earlier while
    trainable (``disc_unsup`` itself, and ``disc_sup`` which shares its trunk) keep
    training normally. Keras 2.x applies the trainable state captured at each model's
    compile time.

    Args:
        gen_model: generator model (latent vector -> image).
        disc_unsup: compiled real/fake discriminator (image -> probability).

    Returns:
        A compiled functional ``Model`` from the generator's input to ``disc_unsup``'s
        output, trained with binary cross-entropy.
    """
    disc_unsup.trainable = False
    gan_output = disc_unsup(gen_model.output)
    model = Model(gen_model.input, gan_output)
    # `learning_rate` replaces the deprecated `lr` alias (removed in Keras 3).
    model.compile(loss='binary_crossentropy', optimizer=Adam(learning_rate=0.0002, beta_1=0.5))
    return model


## Load and Prepare the Dataset

In [7]:
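## The MNIST training set is loaded and preprocessed: a channel axis is added and pixel values are scaled from [0, 255] to [-1, 1] to match the generator's tanh output.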

def load_real_samples(n_classes=10):  # all training images with their labels
    """Load the MNIST training split, scaled for a tanh-output generator.

    Args:
        n_classes: accepted for interface compatibility; not used.

    Returns:
        ``[X, y]``. ``X`` is float32 with a trailing channel axis added and pixel
        values mapped from [0, 255] to [-1, 1]. ``y`` holds the integer labels as
        returned by ``load_data``.
    """
    (train_images, train_labels), _ = load_data()
    images = expand_dims(train_images, axis=-1).astype('float32')
    scaled = (images - 127.5) / 127.5
    return [scaled, train_labels]




## Select Supervised Samples

In [8]:
## A small subset of labeled data is selected for supervised training.

def select_supervised_samples(dataset, n_samples=100, n_classes=10):
    """Draw a small class-balanced labeled subset (100 images by default) from ``dataset``.

    ``n_samples // n_classes`` images are drawn from each class, in class order 0..n_classes-1.
    Sampling is without replacement, so the subset contains distinct images. Drawing
    with replacement could silently yield fewer distinct labeled examples than
    requested. Replacement is used only if a class has fewer images than needed.

    Args:
        dataset: ``[X, y]`` with images ``X`` and integer labels ``y``.
        n_samples: total number of labeled examples to draw.
        n_classes: number of classes (labels assumed to be 0..n_classes-1).

    Returns:
        ``(X_sup, y_sup)`` arrays holding ``n_classes * (n_samples // n_classes)`` examples.

    Raises:
        ValueError: if some class in 0..n_classes-1 has no examples.
    """
    X, y = dataset
    n_per_class = n_samples // n_classes
    X_parts, y_parts = [], []
    for label in range(n_classes):
        X_with_class = X[y == label]
        if len(X_with_class) == 0:
            raise ValueError('no examples of class %d in dataset' % label)
        replace = len(X_with_class) < n_per_class
        ix = np.random.choice(len(X_with_class), n_per_class, replace=replace)
        X_parts.append(X_with_class[ix])
        y_parts.append(np.full(n_per_class, label))
    return np.concatenate(X_parts), np.concatenate(y_parts)



## Generate Real and Fake Samples

In [9]:
def generate_real_samples(dataset, n_samples):  # fetch real images with their labels; every one is tagged real (1)
    """Sample ``n_samples`` (image, label) pairs uniformly with replacement.

    Args:
        dataset: ``[images, labels]`` arrays indexed along their first axis.
        n_samples: number of pairs to draw.

    Returns:
        ``([X, labels], y)`` where ``y`` is an ``(n_samples, 1)`` array of ones.
    """
    images, labels = dataset
    picks = randint(0, images.shape[0], n_samples)
    real_targets = ones((n_samples, 1))
    return [images[picks], labels[picks]], real_targets

def generate_latent_points(latent_dim, n_samples):
    """Return an ``(n_samples, latent_dim)`` array of standard-normal latent vectors."""
    flat = randn(latent_dim * n_samples)
    return flat.reshape(n_samples, latent_dim)

def generate_fake_samples(generator, latent_dim, n_samples):  # generate fake images with the generator; every one is tagged fake (0)
    """Run ``generator.predict`` on fresh latent points.

    Args:
        generator: model exposing ``predict`` on an ``(n_samples, latent_dim)`` array.
        latent_dim: size of each latent vector.
        n_samples: number of images to generate.

    Returns:
        ``(fake_images, y)`` where ``y`` is an ``(n_samples, 1)`` array of zeros.
    """
    latent = generate_latent_points(latent_dim, n_samples)
    images = generator.predict(latent)
    fake_targets = zeros((n_samples, 1))
    return images, fake_targets


In [10]:
# report accuracy and save plots & the model periodically. 

def summarize_performance(step, gen_model, disc_sup, latent_dim, dataset, n_samples=100):
    """Save a grid of generated digits, report classifier accuracy, and checkpoint both models.

    Args:
        step: zero-based training step; output files are tagged with ``step + 1``.
        gen_model: generator used to draw ``n_samples`` fake images.
        disc_sup: supervised classifier to evaluate and save.
        latent_dim: size of the generator's latent input.
        dataset: ``[X, y]`` on which ``disc_sup`` is evaluated. In this notebook it
            is the full training set, so the printed figure is training accuracy.
        n_samples: number of generated images to plot, laid out on a near-square
            grid (10x10 for the default 100).
    """
    import math

    # Generate fake images and rescale from tanh range [-1, 1] to [0, 1] for plotting.
    X, _ = generate_fake_samples(gen_model, latent_dim, n_samples)
    X = (X + 1) / 2.0

    # Size the grid from the actual sample count. A fixed range(100) on a 10x10 grid
    # raises IndexError when fewer than 100 images exist and drops any beyond 100.
    n_plot = len(X)
    n_cols = max(1, math.ceil(math.sqrt(n_plot)))
    n_rows = max(1, math.ceil(n_plot / n_cols))
    plt.figure()  # fresh figure, so nothing already open gets drawn over
    for i in range(n_plot):
        plt.subplot(n_rows, n_cols, 1 + i)
        plt.axis('off')
        plt.imshow(X[i, :, :, 0], cmap='gray_r')
    filename1 = 'generated_plot_%04d.png' % (step + 1)
    plt.savefig(filename1)
    plt.close()

    # Evaluate the supervised classifier on the supplied labeled data.
    X, y = dataset
    _, acc = disc_sup.evaluate(X, y, verbose=0)
    print('Discriminator Accuracy: %.3f%%' % (acc * 100))

    # Checkpoint the generator and the classifier (HDF5 format).
    filename2 = 'gen_model_%04d.h5' % (step + 1)
    gen_model.save(filename2)
    filename3 = 'disc_sup_%04d.h5' % (step + 1)
    disc_sup.save(filename3)
    print('>Saved: %s, %s, and %s' % (filename1, filename2, filename3))

## Train the SGAN

In [11]:
def train(gen_model, disc_unsup, disc_sup, gan_model, dataset, latent_dim, n_epochs=20, n_batch=100):
    """Run the alternating SGAN training loop and summarize once per epoch.

    Each step performs, in order:
      1. one supervised update of ``disc_sup`` on half a batch from the small labeled subset;
      2. one update of ``disc_unsup`` on half a batch of real images (target 1);
      3. one update of ``disc_unsup`` on half a batch of generated images (target 0);
      4. one update of ``gan_model`` on a full batch of latent points labeled 1.
         ``gan_model`` is presumably built by ``define_gan``, i.e. generator followed
         by the frozen real/fake head.
    Losses are printed every step. ``summarize_performance`` runs after every
    ``len(dataset[0]) // n_batch`` steps (one epoch).
    """
    # Fixed labeled subset reused for every supervised update.
    X_sup, y_sup = select_supervised_samples(dataset)

    steps_per_epoch = dataset[0].shape[0] // n_batch
    total_steps = steps_per_epoch * n_epochs
    half = n_batch // 2

    for step in range(total_steps):
        # (1) supervised classifier update
        [X_lab, y_lab], _ = generate_real_samples([X_sup, y_sup], half)
        c_loss, c_acc = disc_sup.train_on_batch(X_lab, y_lab)

        # (2) real/fake head on real images (targets are ones)
        [X_real, _], real_targets = generate_real_samples(dataset, half)
        d_real = disc_unsup.train_on_batch(X_real, real_targets)

        # (3) real/fake head on generated images (targets are zeros)
        X_gen, gen_targets = generate_fake_samples(gen_model, latent_dim, half)
        d_fake = disc_unsup.train_on_batch(X_gen, gen_targets)

        # (4) generator update: ask the frozen head to call fakes "real"
        latent = generate_latent_points(latent_dim, n_batch)
        g_loss = gan_model.train_on_batch(latent, ones((n_batch, 1)))

        print('>%d, c[%.3f,%.0f], d[%.3f,%.3f], g[%.3f]' % (step + 1, c_loss, c_acc * 100, d_real, d_fake, g_loss))

        if (step + 1) % steps_per_epoch == 0:
            summarize_performance(step, gen_model, disc_sup, latent_dim, dataset)


In [12]:
#################################################################################
# TRAIN
#################################

latent_dim = 100  # length of the generator's input noise vector

# Training data: a list [X, y] of the scaled MNIST training images and their labels.
dataset = load_real_samples()

# One shared discriminator trunk with two heads: a softmax classifier (supervised)
# and a real/fake probability (unsupervised). Both update the same trunk weights.
disc = define_discriminator()
disc_sup = define_sup_discriminator(disc)
disc_unsup = define_unsup_discriminator(disc)

# Generator, plus the stacked model used to train it through the frozen real/fake head.
gen_model = define_generator(latent_dim)
gan_model = define_gan(gen_model, disc_unsup)

# With n_batch=100 one epoch is 600 steps, so 10 epochs = 6,000 steps with a summary every 600.
train(gen_model, disc_unsup, disc_sup, gan_model, dataset, latent_dim, n_epochs=10, n_batch=100)



>1, c[2.284,14], d[0.097,2.406], g[0.095]
>2, c[2.305,14], d[0.085,2.407], g[0.096]
>3, c[2.276,14], d[0.080,2.392], g[0.098]
>4, c[2.234,20], d[0.074,2.387], g[0.098]
>5, c[2.153,30], d[0.060,2.377], g[0.101]
>6, c[2.195,22], d[0.042,2.340], g[0.107]
>7, c[2.011,28], d[0.023,2.351], g[0.115]
>8, c[1.919,40], d[0.031,2.163], g[0.169]
>9, c[1.971,32], d[0.069,1.976], g[0.195]
>10, c[1.815,36], d[0.029,1.732], g[0.300]
>11, c[1.667,40], d[0.007,1.372], g[0.603]
>12, c[1.433,54], d[0.003,0.898], g[1.577]
>13, c[1.406,58], d[0.036,0.254], g[2.986]
>14, c[1.799,38], d[0.008,0.162], g[3.028]
>15, c[1.062,72], d[0.005,0.105], g[5.190]
>16, c[1.087,72], d[0.006,0.025], g[4.899]
>17, c[1.293,52], d[0.208,2.604], g[3.346]
>18, c[1.025,70], d[0.005,0.017], g[5.785]
>19, c[1.077,60], d[0.833,3.495], g[1.642]
>20, c[1.033,70], d[0.005,0.153], g[4.353]
>21, c[0.812,76], d[0.301,0.120], g[2.933]
>22, c[0.776,78], d[0.057,0.179], g[4.001]
>23, c[0.809,76], d[0.160,0.256], g[4.205]
>24, c[0.530,88], d[

  saving_api.save_model(


>Saved: generated_plot_0600.png, gen_model_0600.h5, and disc_sup_0600.h5
>601, c[0.007,100], d[0.836,0.992], g[1.177]
>602, c[0.015,100], d[0.884,1.080], g[1.205]
>603, c[0.017,100], d[0.784,0.688], g[1.222]
>604, c[0.019,100], d[0.533,0.708], g[1.288]
>605, c[0.009,100], d[1.008,0.849], g[0.996]
>606, c[0.008,100], d[0.655,0.797], g[1.163]
>607, c[0.012,100], d[0.663,0.891], g[1.305]
>608, c[0.009,100], d[0.662,0.768], g[1.573]
>609, c[0.005,100], d[0.737,0.654], g[1.461]
>610, c[0.020,100], d[0.718,0.819], g[1.288]
>611, c[0.028,100], d[0.749,0.760], g[1.276]
>612, c[0.012,100], d[0.862,0.718], g[1.184]
>613, c[0.013,100], d[0.714,1.088], g[1.353]
>614, c[0.007,100], d[0.703,0.810], g[1.399]
>615, c[0.017,100], d[0.854,0.628], g[1.214]
>616, c[0.013,100], d[0.629,0.804], g[1.158]
>617, c[0.014,100], d[0.625,0.903], g[1.329]
>618, c[0.012,100], d[0.592,0.893], g[1.532]
>619, c[0.012,100], d[0.872,0.745], g[1.631]
>620, c[0.013,100], d[0.906,0.551], g[1.388]
>621, c[0.016,100], d[0.626



>Saved: generated_plot_1200.png, gen_model_1200.h5, and disc_sup_1200.h5
>1201, c[0.010,100], d[0.732,0.944], g[1.318]
>1202, c[0.010,100], d[0.586,0.846], g[1.310]
>1203, c[0.006,100], d[1.006,0.849], g[1.245]
>1204, c[0.005,100], d[0.629,0.627], g[1.118]
>1205, c[0.005,100], d[0.868,0.764], g[1.161]
>1206, c[0.006,100], d[0.683,0.689], g[1.180]
>1207, c[0.004,100], d[0.689,0.689], g[1.102]
>1208, c[0.006,100], d[0.563,0.849], g[1.175]
>1209, c[0.012,100], d[0.604,0.647], g[1.190]
>1210, c[0.013,100], d[0.603,0.897], g[1.257]
>1211, c[0.004,100], d[0.675,0.829], g[1.378]
>1212, c[0.009,100], d[0.686,0.641], g[1.351]
>1213, c[0.004,100], d[0.892,0.643], g[1.259]
>1214, c[0.005,100], d[0.991,0.727], g[1.153]
>1215, c[0.005,100], d[0.533,0.796], g[1.242]
>1216, c[0.011,100], d[0.894,0.629], g[1.104]
>1217, c[0.007,100], d[0.624,0.918], g[1.215]
>1218, c[0.018,100], d[0.825,0.988], g[1.223]
>1219, c[0.010,100], d[0.583,0.788], g[1.307]
>1220, c[0.008,100], d[0.859,0.838], g[1.287]
>1221, 



>Saved: generated_plot_1800.png, gen_model_1800.h5, and disc_sup_1800.h5
>1801, c[0.007,100], d[0.570,0.423], g[1.257]
>1802, c[0.005,100], d[0.649,0.690], g[1.220]
>1803, c[0.006,100], d[0.675,0.889], g[1.079]
>1804, c[0.002,100], d[0.675,0.918], g[1.134]
>1805, c[0.001,100], d[0.722,0.998], g[1.252]
>1806, c[0.004,100], d[0.875,0.701], g[1.284]
>1807, c[0.008,100], d[0.732,0.748], g[1.237]
>1808, c[0.005,100], d[0.786,0.678], g[1.242]
>1809, c[0.005,100], d[0.719,0.847], g[1.334]
>1810, c[0.005,100], d[0.874,0.710], g[1.260]
>1811, c[0.007,100], d[0.634,0.844], g[1.330]
>1812, c[0.005,100], d[0.647,0.777], g[1.336]
>1813, c[0.005,100], d[0.780,0.629], g[1.245]
>1814, c[0.005,100], d[0.948,1.167], g[1.207]
>1815, c[0.004,100], d[0.691,0.740], g[1.294]
>1816, c[0.011,100], d[0.878,0.755], g[1.356]
>1817, c[0.004,100], d[0.762,0.797], g[1.165]
>1818, c[0.006,100], d[0.728,0.872], g[1.070]
>1819, c[0.004,100], d[0.775,0.807], g[1.056]
>1820, c[0.008,100], d[0.709,0.918], g[1.198]
>1821, 



>Saved: generated_plot_2400.png, gen_model_2400.h5, and disc_sup_2400.h5
>2401, c[0.002,100], d[0.987,0.871], g[1.097]
>2402, c[0.008,100], d[0.788,0.845], g[1.242]
>2403, c[0.008,100], d[0.933,1.018], g[1.127]
>2404, c[0.007,100], d[0.775,0.791], g[1.054]
>2405, c[0.003,100], d[0.615,0.824], g[0.991]
>2406, c[0.008,100], d[0.717,0.517], g[1.184]
>2407, c[0.003,100], d[0.744,0.749], g[1.005]
>2408, c[0.003,100], d[0.704,1.038], g[1.176]
>2409, c[0.002,100], d[0.766,0.917], g[1.075]
>2410, c[0.002,100], d[0.878,0.951], g[1.168]
>2411, c[0.004,100], d[1.041,0.806], g[1.118]
>2412, c[0.004,100], d[0.628,0.882], g[1.333]
>2413, c[0.006,100], d[0.810,0.961], g[1.099]
>2414, c[0.004,100], d[0.966,0.749], g[1.059]
>2415, c[0.006,100], d[0.646,0.892], g[1.116]
>2416, c[0.003,100], d[0.651,0.958], g[1.161]
>2417, c[0.009,100], d[0.749,0.614], g[1.231]
>2418, c[0.012,100], d[0.846,0.702], g[1.080]
>2419, c[0.006,100], d[0.746,0.840], g[1.037]
>2420, c[0.003,100], d[0.721,0.963], g[1.124]
>2421, 



>Saved: generated_plot_3000.png, gen_model_3000.h5, and disc_sup_3000.h5
>3001, c[0.004,100], d[0.793,0.865], g[1.258]
>3002, c[0.002,100], d[0.774,0.899], g[1.157]
>3003, c[0.001,100], d[0.659,0.880], g[1.293]
>3004, c[0.003,100], d[0.829,1.012], g[1.178]
>3005, c[0.002,100], d[0.697,0.669], g[1.323]
>3006, c[0.003,100], d[0.725,0.658], g[1.342]
>3007, c[0.005,100], d[0.894,0.671], g[1.153]
>3008, c[0.002,100], d[0.742,0.860], g[1.003]
>3009, c[0.003,100], d[0.639,0.961], g[1.165]
>3010, c[0.002,100], d[0.750,0.610], g[1.336]
>3011, c[0.004,100], d[0.807,0.874], g[1.233]
>3012, c[0.003,100], d[0.834,0.853], g[1.087]
>3013, c[0.002,100], d[0.817,0.744], g[1.139]
>3014, c[0.004,100], d[0.797,0.963], g[1.153]
>3015, c[0.002,100], d[0.691,0.893], g[1.147]
>3016, c[0.003,100], d[0.687,0.563], g[1.349]
>3017, c[0.004,100], d[0.673,1.186], g[1.344]
>3018, c[0.005,100], d[0.744,0.924], g[1.270]
>3019, c[0.003,100], d[0.803,0.645], g[1.304]
>3020, c[0.008,100], d[0.858,0.725], g[1.093]
>3021, 



>Saved: generated_plot_3600.png, gen_model_3600.h5, and disc_sup_3600.h5
>3601, c[0.002,100], d[0.957,0.974], g[1.316]
>3602, c[0.005,100], d[0.686,0.912], g[1.385]
>3603, c[0.002,100], d[0.595,0.789], g[1.225]
>3604, c[0.002,100], d[0.897,0.949], g[1.253]
>3605, c[0.001,100], d[0.869,0.696], g[1.279]
>3606, c[0.002,100], d[0.645,0.853], g[1.142]
>3607, c[0.003,100], d[0.624,0.872], g[1.202]
>3608, c[0.002,100], d[0.631,0.936], g[1.349]
>3609, c[0.004,100], d[0.915,0.833], g[1.109]
>3610, c[0.003,100], d[0.717,0.801], g[1.306]
>3611, c[0.002,100], d[1.058,0.956], g[1.434]
>3612, c[0.003,100], d[0.728,0.925], g[1.330]
>3613, c[0.001,100], d[0.863,1.271], g[1.383]
>3614, c[0.006,100], d[0.778,0.924], g[1.300]
>3615, c[0.003,100], d[1.029,0.873], g[1.176]
>3616, c[0.001,100], d[0.760,1.224], g[1.252]
>3617, c[0.004,100], d[0.894,1.067], g[1.353]
>3618, c[0.004,100], d[0.831,0.984], g[1.304]
>3619, c[0.002,100], d[0.856,0.781], g[1.301]
>3620, c[0.008,100], d[0.817,0.794], g[1.320]
>3621, 



>Saved: generated_plot_4200.png, gen_model_4200.h5, and disc_sup_4200.h5
>4201, c[0.005,100], d[0.664,0.884], g[0.953]
>4202, c[0.008,100], d[0.918,0.903], g[0.922]
>4203, c[0.005,100], d[0.747,0.685], g[0.970]
>4204, c[0.008,100], d[0.713,0.904], g[1.075]
>4205, c[0.009,100], d[0.768,0.767], g[1.107]
>4206, c[0.013,100], d[0.635,0.951], g[1.175]
>4207, c[0.005,100], d[0.774,0.787], g[1.214]
>4208, c[0.004,100], d[0.863,0.751], g[1.220]
>4209, c[0.006,100], d[0.762,0.665], g[1.095]
>4210, c[0.013,100], d[0.777,0.912], g[1.131]
>4211, c[0.005,100], d[0.795,1.031], g[1.015]
>4212, c[0.005,100], d[0.666,0.666], g[1.164]
>4213, c[0.009,100], d[0.901,0.862], g[1.126]
>4214, c[0.006,100], d[0.888,0.788], g[1.010]
>4215, c[0.005,100], d[0.626,0.814], g[0.848]
>4216, c[0.007,100], d[0.609,0.675], g[1.019]
>4217, c[0.003,100], d[0.838,0.746], g[1.083]
>4218, c[0.009,100], d[0.621,0.806], g[1.074]
>4219, c[0.004,100], d[0.824,0.926], g[0.969]
>4220, c[0.003,100], d[0.750,0.826], g[1.196]
>4221, 



>Saved: generated_plot_4800.png, gen_model_4800.h5, and disc_sup_4800.h5
>4801, c[0.003,100], d[0.740,0.666], g[1.048]
>4802, c[0.004,100], d[0.843,0.680], g[1.007]
>4803, c[0.005,100], d[0.715,0.810], g[1.013]
>4804, c[0.005,100], d[0.679,0.848], g[1.003]
>4805, c[0.010,100], d[0.672,0.823], g[1.031]
>4806, c[0.007,100], d[0.733,0.805], g[1.143]
>4807, c[0.012,100], d[0.677,0.636], g[1.075]
>4808, c[0.005,100], d[0.657,0.777], g[1.064]
>4809, c[0.007,100], d[0.811,0.758], g[1.199]
>4810, c[0.005,100], d[0.648,0.685], g[0.998]
>4811, c[0.006,100], d[0.705,0.676], g[1.036]
>4812, c[0.004,100], d[0.660,0.838], g[1.020]
>4813, c[0.005,100], d[0.610,0.793], g[1.009]
>4814, c[0.004,100], d[0.770,0.724], g[1.045]
>4815, c[0.006,100], d[0.827,0.642], g[1.094]
>4816, c[0.003,100], d[0.590,0.596], g[0.937]
>4817, c[0.005,100], d[0.614,0.797], g[1.061]
>4818, c[0.006,100], d[0.741,0.983], g[1.107]
>4819, c[0.006,100], d[0.608,0.712], g[1.081]
>4820, c[0.014,100], d[0.878,0.838], g[1.066]
>4821, 



>Saved: generated_plot_5400.png, gen_model_5400.h5, and disc_sup_5400.h5
>5401, c[0.006,100], d[0.755,0.703], g[1.132]
>5402, c[0.007,100], d[0.807,0.681], g[1.043]
>5403, c[0.004,100], d[0.738,0.735], g[0.922]
>5404, c[0.007,100], d[0.631,0.986], g[0.998]
>5405, c[0.002,100], d[0.596,0.882], g[1.057]
>5406, c[0.006,100], d[0.679,0.714], g[1.160]
>5407, c[0.007,100], d[0.838,0.755], g[1.046]
>5408, c[0.004,100], d[0.778,1.005], g[1.088]
>5409, c[0.011,100], d[0.583,0.702], g[1.056]
>5410, c[0.005,100], d[0.814,0.800], g[1.072]
>5411, c[0.004,100], d[0.796,0.807], g[1.075]
>5412, c[0.009,100], d[0.847,0.673], g[1.027]
>5413, c[0.005,100], d[0.877,0.932], g[1.151]
>5414, c[0.006,100], d[0.827,0.853], g[1.020]
>5415, c[0.006,100], d[0.705,0.795], g[1.026]
>5416, c[0.004,100], d[0.653,0.773], g[1.024]
>5417, c[0.006,100], d[0.794,0.857], g[0.985]
>5418, c[0.006,100], d[0.688,0.757], g[0.903]
>5419, c[0.005,100], d[0.662,0.671], g[0.967]
>5420, c[0.003,100], d[0.743,0.799], g[0.977]
>5421, 



>Saved: generated_plot_6000.png, gen_model_6000.h5, and disc_sup_6000.h5


In [14]:

#############################################################################
#EVALUATE THE SUPERVISED DISCRIMINATOR ON TEST DATA
# This is the model we want as a classifier. 
##################################################################
from keras.models import load_model

# Checkpoint to evaluate: the classifier saved at step 3000 (end of epoch 5) of the run above.
disc_sup_trained_model = load_model('disc_sup_3000.h5')

# Held-out MNIST test split, preprocessed exactly like the training images:
# add a channel axis, cast to float32, and scale [0, 255] -> [-1, 1].
(_, _), (testX, testy) = load_data()
testX = (expand_dims(testX, axis=-1).astype('float32') - 127.5) / 127.5

_, test_acc = disc_sup_trained_model.evaluate(testX, testy, verbose=0)
print('Test Accuracy: %.3f%%' % (test_acc * 100))

# Per-image class probabilities and the predicted digit (argmax over the class axis).
y_pred_test = disc_sup_trained_model.predict(testX)
prediction_test = np.argmax(y_pred_test, axis=1)

Test Accuracy: 92.920%
